From 33e8b9f2846ca2e8c28ad9e468a84a1759df3385 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 16 Apr 2025 14:24:28 +0800 Subject: [PATCH] Fix edge direction problem for Neo4j storage --- lightrag/kg/neo4j_impl.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 1b712462..06abfe02 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -665,31 +665,56 @@ class Neo4JStorage(BaseGraphStorage): ) -> dict[str, list[tuple[str, str]]]: """ Batch retrieve edges for multiple nodes in one query using UNWIND. + For each node, returns both outgoing and incoming edges to properly represent + the undirected graph nature. Args: node_ids: List of node IDs (entity_id) for which to retrieve edges. Returns: A dictionary mapping each node ID to its list of edge tuples (source, target). + For each node, the list includes both: + - Outgoing edges: (queried_node, connected_node) + - Incoming edges: (connected_node, queried_node) """ async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: + # Query to get both outgoing and incoming edges query = """ UNWIND $node_ids AS id MATCH (n:base {entity_id: id}) OPTIONAL MATCH (n)-[r]-(connected:base) - RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id + RETURN id AS queried_id, n.entity_id AS node_entity_id, + connected.entity_id AS connected_entity_id, + startNode(r).entity_id AS start_entity_id """ result = await session.run(query, node_ids=node_ids) + # Initialize the dictionary with empty lists for each node ID edges_dict = {node_id: [] for node_id in node_ids} + + # Process results to include both outgoing and incoming edges async for record in result: queried_id = record["queried_id"] - source_label = record["source_entity_id"] - target_label = record["target_entity_id"] - if source_label and target_label: - edges_dict[queried_id].append((source_label, target_label)) + node_entity_id = record["node_entity_id"] + connected_entity_id = record["connected_entity_id"] + start_entity_id = record["start_entity_id"] + + # Skip if either node is None + if not node_entity_id or not connected_entity_id: + continue + + # Determine the actual direction of the edge + # If the start node is the queried node, it's an outgoing edge + # Otherwise, it's an incoming edge + if start_entity_id == node_entity_id: + # Outgoing edge: (queried_node -> connected_node) + edges_dict[queried_id].append((node_entity_id, connected_entity_id)) + else: + # Incoming edge: (connected_node -> queried_node) + edges_dict[queried_id].append((connected_entity_id, node_entity_id)) + await result.consume() # Ensure results are fully consumed return edges_dict