Merge branch 'graph-storage-batch-query-frederikhendrix' into graph-storage-batch-query
This commit is contained in:
@@ -11,3 +11,26 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
restart: unless-stopped
|
||||
|
||||
neo4j:
|
||||
image: neo4j:5.26.4-community
|
||||
container_name: lightrag-server_neo4j-community
|
||||
restart: always
|
||||
ports:
|
||||
- "7474:7474"
|
||||
- "7687:7687"
|
||||
environment:
|
||||
- NEO4J_AUTH=${NEO4J_USERNAME}/${NEO4J_PASSWORD}
|
||||
- NEO4J_apoc_export_file_enabled=true
|
||||
- NEO4J_server_bolt_listen__address=0.0.0.0:7687
|
||||
- NEO4J_server_bolt_advertised__address=neo4j:7687
|
||||
volumes:
|
||||
- ./neo4j/plugins:/var/lib/neo4j/plugins
|
||||
- lightrag_neo4j_import:/var/lib/neo4j/import
|
||||
- lightrag_neo4j_data:/data
|
||||
- lightrag_neo4j_backups:/backups
|
||||
|
||||
volumes:
|
||||
lightrag_neo4j_import:
|
||||
lightrag_neo4j_data:
|
||||
lightrag_neo4j_backups:
|
||||
|
@@ -361,6 +361,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
or None if the node doesn't exist
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""Get nodes as a batch using UNWIND"""
|
||||
|
||||
@abstractmethod
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""Node degrees as a batch using UNWIND"""
|
||||
|
||||
@abstractmethod
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
"""Get edges as a batch using UNWIND"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
""""Get nodes edges as a batch using UNWIND"""
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
"""Insert a new node or update an existing node in the graph.
|
||||
|
@@ -308,6 +308,37 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error getting node for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Retrieve multiple nodes in one query using UNWIND.
|
||||
|
||||
Args:
|
||||
node_ids: List of node entity IDs to fetch.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its node data (or None if not found).
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
RETURN n.entity_id AS entity_id, n
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
nodes = {}
|
||||
async for record in result:
|
||||
entity_id = record["entity_id"]
|
||||
node = record["n"]
|
||||
node_dict = dict(node)
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
nodes[entity_id] = node_dict
|
||||
await result.consume() # Make sure to consume the result fully
|
||||
return nodes
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""Get the degree (number of relationships) of a node with the given label.
|
||||
If multiple nodes have the same label, returns the degree of the first node.
|
||||
@@ -351,6 +382,41 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||
|
||||
Args:
|
||||
node_ids: List of node labels (entity_id values) to look up.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
If a node is not found, its degree will be set to 0.
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
degrees = {}
|
||||
async for record in result:
|
||||
entity_id = record["entity_id"]
|
||||
degrees[entity_id] = record["degree"]
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
# For any node_id that did not return a record, set degree to 0.
|
||||
for nid in node_ids:
|
||||
if nid not in degrees:
|
||||
logger.warning(f"No node found with label '{nid}'")
|
||||
degrees[nid] = 0
|
||||
|
||||
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
||||
return degrees
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""Get the total degree (sum of relationships) of two nodes.
|
||||
|
||||
@@ -371,6 +437,30 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
return degrees
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
"""
|
||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||
in batch using the already implemented node_degrees_batch.
|
||||
|
||||
Args:
|
||||
edge_pairs: List of (src, tgt) tuples.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
||||
"""
|
||||
# Collect unique node IDs from all edge pairs.
|
||||
unique_node_ids = {src for src, _ in edge_pairs}
|
||||
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
||||
|
||||
# Get degrees for all nodes in one go.
|
||||
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
||||
|
||||
# Sum up degrees for each edge pair.
|
||||
edge_degrees = {}
|
||||
for src, tgt in edge_pairs:
|
||||
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
||||
return edge_degrees
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
@@ -457,6 +547,43 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
"""
|
||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||
|
||||
Args:
|
||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||
|
||||
Returns:
|
||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $pairs AS pair
|
||||
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
|
||||
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
|
||||
"""
|
||||
result = await session.run(query, pairs=pairs)
|
||||
edges_dict = {}
|
||||
async for record in result:
|
||||
src = record["src_id"]
|
||||
tgt = record["tgt_id"]
|
||||
edges = record["edges"]
|
||||
if edges and len(edges) > 0:
|
||||
edge_props = edges[0] # choose the first if multiple exist
|
||||
# Ensure required keys exist with defaults
|
||||
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
|
||||
if key not in edge_props:
|
||||
edge_props[key] = default
|
||||
edges_dict[(src, tgt)] = edge_props
|
||||
else:
|
||||
# No edge found – set default edge properties
|
||||
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
|
||||
await result.consume()
|
||||
return edges_dict
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
|
||||
@@ -517,6 +644,35 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
OPTIONAL MATCH (n)-[r]-(connected:base)
|
||||
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
# Initialize the dictionary with empty lists for each node ID
|
||||
edges_dict = {node_id: [] for node_id in node_ids}
|
||||
async for record in result:
|
||||
queried_id = record["queried_id"]
|
||||
source_label = record["source_entity_id"]
|
||||
target_label = record["target_entity_id"]
|
||||
if source_label and target_label:
|
||||
edges_dict[queried_id].append((source_label, target_label))
|
||||
await result.consume() # Ensure results are fully consumed
|
||||
return edges_dict
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
@@ -1323,16 +1323,20 @@ async def _get_node_data(
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
# get entity information
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
|
||||
),
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids)
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
node_datas = [nodes_dict.get(nid) for nid in node_ids]
|
||||
node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
|
||||
|
||||
if not all([n is not None for n in node_datas]):
|
||||
logger.warning("Some nodes are missing, maybe the storage is damaged")
|
||||
|
||||
@@ -1455,9 +1459,12 @@ async def _find_most_related_text_unit_from_entities(
|
||||
for dp in node_datas
|
||||
if dp["source_id"] is not None
|
||||
]
|
||||
edges = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
|
||||
)
|
||||
|
||||
node_names = [dp["entity_name"] for dp in node_datas]
|
||||
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
|
||||
# Build the edges list in the same order as node_datas.
|
||||
edges = [batch_edges_dict.get(name, []) for name in node_names]
|
||||
|
||||
all_one_hop_nodes = set()
|
||||
for this_edges in edges:
|
||||
if not this_edges:
|
||||
@@ -1465,9 +1472,10 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_one_hop_nodes.update([e[1] for e in this_edges])
|
||||
|
||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = await asyncio.gather(
|
||||
*[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
|
||||
)
|
||||
|
||||
# Batch retrieve one-hop node data using get_nodes_batch
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
||||
|
||||
# Add null check for node data
|
||||
all_one_hop_text_units_lookup = {
|
||||
@@ -1558,17 +1566,31 @@ async def _find_most_related_edges_from_entities(
|
||||
seen.add(sorted_edge)
|
||||
all_edges.append(sorted_edge)
|
||||
|
||||
all_edges_pack, all_edges_degree = await asyncio.gather(
|
||||
asyncio.gather(*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]),
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
|
||||
),
|
||||
# Prepare edge pairs in two forms:
|
||||
# For the batch edge properties function, use dicts.
|
||||
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
|
||||
# For edge degrees, use tuples.
|
||||
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
|
||||
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
||||
)
|
||||
all_edges_data = [
|
||||
{"src_tgt": k, "rank": d, **v}
|
||||
for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
|
||||
if v is not None
|
||||
]
|
||||
|
||||
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
||||
all_edges_data = []
|
||||
for pair in all_edges:
|
||||
edge_props = edge_data_dict.get(pair)
|
||||
if edge_props is not None:
|
||||
combined = {
|
||||
"src_tgt": pair,
|
||||
"rank": edge_degrees_dict.get(pair, 0),
|
||||
**edge_props,
|
||||
}
|
||||
all_edges_data.append(combined)
|
||||
|
||||
|
||||
all_edges_data = sorted(
|
||||
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1603,29 +1625,34 @@ async def _get_edge_data(
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
|
||||
edge_datas, edge_degree = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
|
||||
for r in results
|
||||
]
|
||||
),
|
||||
# Prepare edge pairs in two forms:
|
||||
# For the batch edge properties function, use dicts.
|
||||
edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
|
||||
# For edge degrees, use tuples.
|
||||
edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
|
||||
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
||||
)
|
||||
|
||||
edge_datas = [
|
||||
{
|
||||
"src_id": k["src_id"],
|
||||
"tgt_id": k["tgt_id"],
|
||||
"rank": d,
|
||||
"created_at": k.get("__created_at__", None),
|
||||
**v,
|
||||
}
|
||||
for k, v, d in zip(results, edge_datas, edge_degree)
|
||||
if v is not None
|
||||
]
|
||||
# Reconstruct edge_datas list in the same order as results.
|
||||
edge_datas = []
|
||||
for k in results:
|
||||
pair = (k["src_id"], k["tgt_id"])
|
||||
edge_props = edge_data_dict.get(pair)
|
||||
if edge_props is not None:
|
||||
# Use edge degree from the batch as rank.
|
||||
combined = {
|
||||
"src_id": k["src_id"],
|
||||
"tgt_id": k["tgt_id"],
|
||||
"rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
|
||||
"created_at": k.get("__created_at__", None),
|
||||
**edge_props,
|
||||
}
|
||||
edge_datas.append(combined)
|
||||
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1731,25 +1758,23 @@ async def _find_most_related_entities_from_relationships(
|
||||
entity_names.append(e["tgt_id"])
|
||||
seen.add(e["tgt_id"])
|
||||
|
||||
node_datas, node_degrees = await asyncio.gather(
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.get_node(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
asyncio.gather(
|
||||
*[
|
||||
knowledge_graph_inst.node_degree(entity_name)
|
||||
for entity_name in entity_names
|
||||
]
|
||||
),
|
||||
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(entity_names),
|
||||
knowledge_graph_inst.node_degrees_batch(entity_names)
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k, "rank": d}
|
||||
for k, n, d in zip(entity_names, node_datas, node_degrees)
|
||||
if n is not None
|
||||
]
|
||||
|
||||
# Rebuild the list in the same order as entity_names
|
||||
node_datas = []
|
||||
for entity_name in entity_names:
|
||||
node = nodes_dict.get(entity_name)
|
||||
degree = degrees_dict.get(entity_name, 0)
|
||||
if node is None:
|
||||
logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
|
||||
continue
|
||||
# Combine the node data with the entity name and computed degree (as rank)
|
||||
combined = {**node, "entity_name": entity_name, "rank": degree}
|
||||
node_datas.append(combined)
|
||||
|
||||
len_node_datas = len(node_datas)
|
||||
node_datas = truncate_list_by_token_size(
|
||||
|
Reference in New Issue
Block a user