Fix edge direction problem for Neo4j storage
This commit is contained in:
@@ -665,31 +665,56 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||
For each node, returns both outgoing and incoming edges to properly represent
|
||||
the undirected graph nature.
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||
For each node, the list includes both:
|
||||
- Outgoing edges: (queried_node, connected_node)
|
||||
- Incoming edges: (connected_node, queried_node)
|
||||
"""
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
# Query to get both outgoing and incoming edges
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
OPTIONAL MATCH (n)-[r]-(connected:base)
|
||||
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
|
||||
RETURN id AS queried_id, n.entity_id AS node_entity_id,
|
||||
connected.entity_id AS connected_entity_id,
|
||||
startNode(r).entity_id AS start_entity_id
|
||||
"""
|
||||
result = await session.run(query, node_ids=node_ids)
|
||||
|
||||
# Initialize the dictionary with empty lists for each node ID
|
||||
edges_dict = {node_id: [] for node_id in node_ids}
|
||||
|
||||
# Process results to include both outgoing and incoming edges
|
||||
async for record in result:
|
||||
queried_id = record["queried_id"]
|
||||
source_label = record["source_entity_id"]
|
||||
target_label = record["target_entity_id"]
|
||||
if source_label and target_label:
|
||||
edges_dict[queried_id].append((source_label, target_label))
|
||||
node_entity_id = record["node_entity_id"]
|
||||
connected_entity_id = record["connected_entity_id"]
|
||||
start_entity_id = record["start_entity_id"]
|
||||
|
||||
# Skip if either node is None
|
||||
if not node_entity_id or not connected_entity_id:
|
||||
continue
|
||||
|
||||
# Determine the actual direction of the edge
|
||||
# If the start node is the queried node, it's an outgoing edge
|
||||
# Otherwise, it's an incoming edge
|
||||
if start_entity_id == node_entity_id:
|
||||
# Outgoing edge: (queried_node -> connected_node)
|
||||
edges_dict[queried_id].append((node_entity_id, connected_entity_id))
|
||||
else:
|
||||
# Incoming edge: (connected_node -> queried_node)
|
||||
edges_dict[queried_id].append((connected_entity_id, node_entity_id))
|
||||
|
||||
await result.consume() # Ensure results are fully consumed
|
||||
return edges_dict
|
||||
|
||||
|
Reference in New Issue
Block a user