Fix edge direction problem for Neo4j storage

This commit is contained in:
yangdx
2025-04-16 14:24:28 +08:00
parent 2a950f3ff9
commit 33e8b9f284

View File

@@ -665,31 +665,56 @@ class Neo4JStorage(BaseGraphStorage):
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]:
""" """
Batch retrieve edges for multiple nodes in one query using UNWIND. Batch retrieve edges for multiple nodes in one query using UNWIND.
For each node, returns both outgoing and incoming edges to properly represent
the undirected graph nature.
Args: Args:
node_ids: List of node IDs (entity_id) for which to retrieve edges. node_ids: List of node IDs (entity_id) for which to retrieve edges.
Returns: Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target). A dictionary mapping each node ID to its list of edge tuples (source, target).
For each node, the list includes both:
- Outgoing edges: (queried_node, connected_node)
- Incoming edges: (connected_node, queried_node)
""" """
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
# Query to get both outgoing and incoming edges
query = """ query = """
UNWIND $node_ids AS id UNWIND $node_ids AS id
MATCH (n:base {entity_id: id}) MATCH (n:base {entity_id: id})
OPTIONAL MATCH (n)-[r]-(connected:base) OPTIONAL MATCH (n)-[r]-(connected:base)
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id RETURN id AS queried_id, n.entity_id AS node_entity_id,
connected.entity_id AS connected_entity_id,
startNode(r).entity_id AS start_entity_id
""" """
result = await session.run(query, node_ids=node_ids) result = await session.run(query, node_ids=node_ids)
# Initialize the dictionary with empty lists for each node ID # Initialize the dictionary with empty lists for each node ID
edges_dict = {node_id: [] for node_id in node_ids} edges_dict = {node_id: [] for node_id in node_ids}
# Process results to include both outgoing and incoming edges
async for record in result: async for record in result:
queried_id = record["queried_id"] queried_id = record["queried_id"]
source_label = record["source_entity_id"] node_entity_id = record["node_entity_id"]
target_label = record["target_entity_id"] connected_entity_id = record["connected_entity_id"]
if source_label and target_label: start_entity_id = record["start_entity_id"]
edges_dict[queried_id].append((source_label, target_label))
# Skip if either node is None
if not node_entity_id or not connected_entity_id:
continue
# Determine the actual direction of the edge
# If the start node is the queried node, it's an outgoing edge
# Otherwise, it's an incoming edge
if start_entity_id == node_entity_id:
# Outgoing edge: (queried_node -> connected_node)
edges_dict[queried_id].append((node_entity_id, connected_entity_id))
else:
# Incoming edge: (connected_node -> queried_node)
edges_dict[queried_id].append((connected_entity_id, node_entity_id))
await result.consume() # Ensure results are fully consumed await result.consume() # Ensure results are fully consumed
return edges_dict return edges_dict