Merge branch 'graph-storage-batch-query-frederikhendrix' into graph-storage-batch-query

This commit is contained in:
yangdx
2025-04-12 22:20:41 +08:00
4 changed files with 287 additions and 63 deletions

View File

@@ -11,3 +11,26 @@ services:
env_file:
- .env
restart: unless-stopped
neo4j:
image: neo4j:5.26.4-community
container_name: lightrag-server_neo4j-community
restart: always
ports:
- "7474:7474"
- "7687:7687"
environment:
- NEO4J_AUTH=${NEO4J_USERNAME}/${NEO4J_PASSWORD}
- NEO4J_apoc_export_file_enabled=true
- NEO4J_server_bolt_listen__address=0.0.0.0:7687
- NEO4J_server_bolt_advertised__address=neo4j:7687
volumes:
- ./neo4j/plugins:/var/lib/neo4j/plugins
- lightrag_neo4j_import:/var/lib/neo4j/import
- lightrag_neo4j_data:/data
- lightrag_neo4j_backups:/backups
volumes:
lightrag_neo4j_import:
lightrag_neo4j_data:
lightrag_neo4j_backups:

View File

@@ -361,6 +361,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
or None if the node doesn't exist
"""
@abstractmethod
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""Get nodes as a batch using UNWIND"""
@abstractmethod
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""Node degrees as a batch using UNWIND"""
@abstractmethod
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
@abstractmethod
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND"""
@abstractmethod
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
""""Get nodes edges as a batch using UNWIND"""
@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.

View File

@@ -308,6 +308,37 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node for {node_id}: {str(e)}")
raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, n
"""
result = await session.run(query, node_ids=node_ids)
nodes = {}
async for record in result:
entity_id = record["entity_id"]
node = record["n"]
node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully
return nodes
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
@@ -351,6 +382,41 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
"""
result = await session.run(query, node_ids=node_ids)
degrees = {}
async for record in result:
entity_id = record["entity_id"]
degrees[entity_id] = record["degree"]
await result.consume() # Ensure result is fully consumed
# For any node_id that did not return a record, set degree to 0.
for nid in node_ids:
if nid not in degrees:
logger.warning(f"No node found with label '{nid}'")
degrees[nid] = 0
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes.
@@ -371,6 +437,30 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree)
return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edge_pairs: List of (src, tgt) tuples.
Returns:
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
"""
# Collect unique node IDs from all edge pairs.
unique_node_ids = {src for src, _ in edge_pairs}
unique_node_ids.update({tgt for _, tgt in edge_pairs})
# Get degrees for all nodes in one go.
degrees = await self.node_degrees_batch(list(unique_node_ids))
# Sum up degrees for each edge pair.
edge_degrees = {}
for src, tgt in edge_pairs:
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
@@ -457,6 +547,43 @@ class Neo4JStorage(BaseGraphStorage):
)
raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
UNWIND $pairs AS pair
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
"""
result = await session.run(query, pairs=pairs)
edges_dict = {}
async for record in result:
src = record["src_id"]
tgt = record["tgt_id"]
edges = record["edges"]
if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
if key not in edge_props:
edge_props[key] = default
edges_dict[(src, tgt)] = edge_props
else:
# No edge found set default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
await result.consume()
return edges_dict
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
@@ -517,6 +644,35 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
"""
Batch retrieve edges for multiple nodes in one query using UNWIND.
Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target).
"""
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
OPTIONAL MATCH (n)-[r]-(connected:base)
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
"""
result = await session.run(query, node_ids=node_ids)
# Initialize the dictionary with empty lists for each node ID
edges_dict = {node_id: [] for node_id in node_ids}
async for record in result:
queried_id = record["queried_id"]
source_label = record["source_entity_id"]
target_label = record["target_entity_id"]
if source_label and target_label:
edges_dict[queried_id].append((source_label, target_label))
await result.consume() # Ensure results are fully consumed
return edges_dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),

View File

@@ -1323,16 +1323,20 @@ async def _get_node_data(
if not len(results):
return "", "", ""
# get entity information
node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
),
asyncio.gather(
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
),
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
# Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids)
)
# Now, if you need the node data and degree in order:
node_datas = [nodes_dict.get(nid) for nid in node_ids]
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
if not all([n is not None for n in node_datas]):
logger.warning("Some nodes are missing, maybe the storage is damaged")
@@ -1455,9 +1459,12 @@ async def _find_most_related_text_unit_from_entities(
for dp in node_datas
if dp["source_id"] is not None
]
edges = await asyncio.gather(
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
)
node_names = [dp["entity_name"] for dp in node_datas]
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
# Build the edges list in the same order as node_datas.
edges = [batch_edges_dict.get(name, []) for name in node_names]
all_one_hop_nodes = set()
for this_edges in edges:
if not this_edges:
@@ -1465,9 +1472,10 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes.update([e[1] for e in this_edges])
all_one_hop_nodes = list(all_one_hop_nodes)
all_one_hop_nodes_data = await asyncio.gather(
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
)
# Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
# Add null check for node data
all_one_hop_text_units_lookup = {
@@ -1558,17 +1566,31 @@ async def _find_most_related_edges_from_entities(
seen.add(sorted_edge)
all_edges.append(sorted_edge)
all_edges_pack, all_edges_degree = await asyncio.gather(
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
asyncio.gather(
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
),
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
# For edge degrees, use tuples.
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
)
all_edges_data = [
{"src_tgt": k, "rank": d, **v}
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
if v is not None
]
# Reconstruct edge_datas list in the same order as the deduplicated results.
all_edges_data = []
for pair in all_edges:
edge_props = edge_data_dict.get(pair)
if edge_props is not None:
combined = {
"src_tgt": pair,
"rank": edge_degrees_dict.get(pair, 0),
**edge_props,
}
all_edges_data.append(combined)
all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1603,29 +1625,34 @@ async def _get_edge_data(
if not len(results):
return "", "", ""
edge_datas, edge_degree = await asyncio.gather(
asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
),
asyncio.gather(
*[
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
for r in results
]
),
# Prepare edge pairs in two forms:
# For the batch edge properties function, use dicts.
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
# For edge degrees, use tuples.
edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
# Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
)
edge_datas = [
{
# Reconstruct edge_datas list in the same order as results.
edge_datas = []
for k in results:
pair = (k["src_id"], k["tgt_id"])
edge_props = edge_data_dict.get(pair)
if edge_props is not None:
# Use edge degree from the batch as rank.
combined = {
"src_id": k["src_id"],
"tgt_id": k["tgt_id"],
"rank": d,
"rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
"created_at": k.get("__created_at__", None),
**v,
**edge_props,
}
for k, v, d in zip(results, edge_datas, edge_degree)
if v is not None
]
edge_datas.append(combined)
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
@@ -1731,25 +1758,23 @@ async def _find_most_related_entities_from_relationships(
entity_names.append(e["tgt_id"])
seen.add(e["tgt_id"])
node_datas, node_degrees = await asyncio.gather(
asyncio.gather(
*[
knowledge_graph_inst.get_node(entity_name)
for entity_name in entity_names
]
),
asyncio.gather(
*[
knowledge_graph_inst.node_degree(entity_name)
for entity_name in entity_names
]
),
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.node_degrees_batch(entity_names)
)
node_datas = [
{**n, "entity_name": k, "rank": d}
for k, n, d in zip(entity_names, node_datas, node_degrees)
if n is not None
]
# Rebuild the list in the same order as entity_names
node_datas = []
for entity_name in entity_names:
node = nodes_dict.get(entity_name)
degree = degrees_dict.get(entity_name, 0)
if node is None:
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
continue
# Combine the node data with the entity name and computed degree (as rank)
combined = {**node, "entity_name": entity_name, "rank": degree}
node_datas.append(combined)
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(