Merge branch 'graph-storage-batch-query-frederikhendrix' into graph-storage-batch-query

This commit is contained in:
yangdx
2025-04-12 22:20:41 +08:00
4 changed files with 287 additions and 63 deletions

View File

@@ -11,3 +11,26 @@ services:
env_file: env_file:
- .env - .env
restart: unless-stopped restart: unless-stopped
neo4j:
image: neo4j:5.26.4-community
container_name: lightrag-server_neo4j-community
restart: always
ports:
- "7474:7474"
- "7687:7687"
environment:
- NEO4J_AUTH=${NEO4J_USERNAME}/${NEO4J_PASSWORD}
- NEO4J_apoc_export_file_enabled=true
- NEO4J_server_bolt_listen__address=0.0.0.0:7687
- NEO4J_server_bolt_advertised__address=neo4j:7687
volumes:
- ./neo4j/plugins:/var/lib/neo4j/plugins
- lightrag_neo4j_import:/var/lib/neo4j/import
- lightrag_neo4j_data:/data
- lightrag_neo4j_backups:/backups
volumes:
lightrag_neo4j_import:
lightrag_neo4j_data:
lightrag_neo4j_backups:

View File

@@ -361,6 +361,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist or None if the node doesn't exist
""" """
@abstractmethod
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND"""
@abstractmethod
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Node degrees as a batch using UNWIND"""
@abstractmethod
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
@abstractmethod
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND"""
@abstractmethod
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
""""Get nodes edges as a batch using UNWIND"""
@abstractmethod @abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph. """Insert a new node or update an existing node in the graph.

View File

@@ -308,6 +308,37 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node for {node_id}: {str(e)}") logger.error(f"Error getting node for {node_id}: {str(e)}")
raise raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, n
"""
result = await session.run(query, node_ids=node_ids)
nodes = {}
async for record in result:
entity_id = record["entity_id"]
node = record["n"]
node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully
return nodes
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label. """Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node. If multiple nodes have the same label, returns the degree of the first node.
@@ -351,6 +382,41 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node degree for {node_id}: {str(e)}") logger.error(f"Error getting node degree for {node_id}: {str(e)}")
raise raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
"""
result = await session.run(query, node_ids=node_ids)
degrees = {}
async for record in result:
entity_id = record["entity_id"]
degrees[entity_id] = record["degree"]
await result.consume() # Ensure result is fully consumed
# For any node_id that did not return a record, set degree to 0.
for nid in node_ids:
if nid not in degrees:
logger.warning(f"No node found with label '{nid}'")
degrees[nid] = 0
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes. """Get the total degree (sum of relationships) of two nodes.
@@ -370,6 +436,30 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edge_pairs: List of (src, tgt) tuples.
Returns:
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
"""
# Collect unique node IDs from all edge pairs.
unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes in one go.
degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair.
edge_degrees = {}
for src, tgt in edge_pairs:
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
@@ -457,6 +547,43 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $pairs AS pair
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
"""
result = await session.run(query, pairs=pairs)
edges_dict = {}
async for record in result:
src = record["src_id"]
tgt = record["tgt_id"]
edges = record["edges"]
if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
if key not in edge_props:
edge_props[key] = default
edges_dict[(src, tgt)] = edge_props
else:
# No edge found set default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
await result.consume()
return edges_dict
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label. """Retrieves all edges (relationships) for a particular node identified by its label.
@@ -517,6 +644,35 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes in one query using UNWIND.
Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
OPTIONAL MATCH (n)-[r]-(connected:base)
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
"""
result = await session.run(query, node_ids=node_ids)
# Initialize the dictionary with empty lists for each node ID
edges_dict = {node_id: [] for node_id in node_ids}
async for record in result:
queried_id = record["queried_id"]
source_label = record["source_entity_id"]
target_label = record["target_entity_id"]
if source_label and target_label:
edges_dict[queried_id].append((source_label, target_label))
await result.consume() # Ensure results are fully consumed
return edges_dict
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),

View File

@@ -1323,16 +1323,20 @@ async def _get_node_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
# get entity information
node_datas, node_degrees = await asyncio.gather( # Extract all entity IDs from your results list
asyncio.gather( node_ids = [r["entity_name"] for r in results]
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
), # Call the batch node retrieval and degree functions concurrently.
asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] knowledge_graph_inst.get_nodes_batch(node_ids),
), knowledge_graph_inst.node_degrees_batch(node_ids)
) )
# Now, if you need the node data and degree in order:
node_datas = [nodes_dict.get(nid) for nid in node_ids]
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
if not all([n is not None for n in node_datas]): if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged") logger.warning("Some nodes are missing, maybe the storage is damaged")
@@ -1455,9 +1459,12 @@ async def _find_most_related_text_unit_from_entities(
for dp in node_datas for dp in node_datas
if dp["source_id"] is not None if dp["source_id"] is not None
] ]
edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas] node_names = [dp["entity_name"] for dp in node_datas]
) batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
# Build the edges list in the same order as node_datas.
edges = [batch_edges_dict.get(name, []) for name in node_names]
all_one_hop_nodes = set() all_one_hop_nodes = set()
for this_edges in edges: for this_edges in edges:
if not this_edges: if not this_edges:
@@ -1465,9 +1472,10 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes.update([e[1] for e in this_edges]) all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes) all_one_hop_nodes = list(all_one_hop_nodes)
all_one_hop_nodes_data = await asyncio.gather(
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes] # Batch retrieve one-hop node data using get_nodes_batch
) all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
# Add null check for node data # Add null check for node data
all_one_hop_text_units_lookup = { all_one_hop_text_units_lookup = {
@@ -1558,17 +1566,31 @@ async def _find_most_related_edges_from_entities(
seen.add(sorted_edge) seen.add(sorted_edge)
all_edges.append(sorted_edge) all_edges.append(sorted_edge)
all_edges_pack, all_edges_degree = await asyncio.gather( # Prepare edge pairs in two forms:
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]), # For the batch edge properties function, use dicts.
asyncio.gather( edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges] # For edge degrees, use tuples.
), edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
) )
all_edges_data = [
{"src_tgt": k, "rank": d, **v} # Reconstruct edge_datas list in the same order as the deduplicated results.
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree) all_edges_data = []
if v is not None for pair in all_edges:
] edge_props = edge_data_dict.get(pair)
if edge_props is not None:
combined = {
"src_tgt": pair,
"rank": edge_degrees_dict.get(pair, 0),
**edge_props,
}
all_edges_data.append(combined)
all_edges_data = sorted( all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1603,29 +1625,34 @@ async def _get_edge_data(
if not len(results): if not len(results):
return "", "", "" return "", "", ""
edge_datas, edge_degree = await asyncio.gather( # Prepare edge pairs in two forms:
asyncio.gather( # For the batch edge properties function, use dicts.
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
), # For edge degrees, use tuples.
asyncio.gather( edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
*[
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) # Call the batched functions concurrently.
for r in results edge_data_dict, edge_degrees_dict = await asyncio.gather(
] knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
), knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
) )
edge_datas = [ # Reconstruct edge_datas list in the same order as results.
{ edge_datas = []
"src_id": k["src_id"], for k in results:
"tgt_id": k["tgt_id"], pair = (k["src_id"], k["tgt_id"])
"rank": d, edge_props = edge_data_dict.get(pair)
"created_at": k.get("__created_at__", None), if edge_props is not None:
**v, # Use edge degree from the batch as rank.
} combined = {
for k, v, d in zip(results, edge_datas, edge_degree) "src_id": k["src_id"],
if v is not None "tgt_id": k["tgt_id"],
] "rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
"created_at": k.get("__created_at__", None),
**edge_props,
}
edge_datas.append(combined)
edge_datas = sorted( edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1731,25 +1758,23 @@ async def _find_most_related_entities_from_relationships(
entity_names.append(e["tgt_id"]) entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"]) seen.add(e["tgt_id"])
node_datas, node_degrees = await asyncio.gather( # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
*[ knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.get_node(entity_name) knowledge_graph_inst.node_degrees_batch(entity_names)
for entity_name in entity_names
]
),
asyncio.gather(
*[
knowledge_graph_inst.node_degree(entity_name)
for entity_name in entity_names
]
),
) )
node_datas = [
{**n, "entity_name": k, "rank": d} # Rebuild the list in the same order as entity_names
for k, n, d in zip(entity_names, node_datas, node_degrees) node_datas = []
if n is not None for entity_name in entity_names:
] node = nodes_dict.get(entity_name)
degree = degrees_dict.get(entity_name, 0)
if node is None:
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
continue
# Combine the node data with the entity name and computed degree (as rank)
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
len_node_datas = len(node_datas) len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size( node_datas = truncate_list_by_token_size(