Fix edge direction problem for Neo4j storage
This commit is contained in:
@@ -665,31 +665,56 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
"""
|
"""
|
||||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||||
|
For each node, returns both outgoing and incoming edges to properly represent
|
||||||
|
the undirected graph nature.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||||
|
For each node, the list includes both:
|
||||||
|
- Outgoing edges: (queried_node, connected_node)
|
||||||
|
- Incoming edges: (connected_node, queried_node)
|
||||||
"""
|
"""
|
||||||
async with self._driver.session(
|
async with self._driver.session(
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
) as session:
|
) as session:
|
||||||
|
# Query to get both outgoing and incoming edges
|
||||||
query = """
|
query = """
|
||||||
UNWIND $node_ids AS id
|
UNWIND $node_ids AS id
|
||||||
MATCH (n:base {entity_id: id})
|
MATCH (n:base {entity_id: id})
|
||||||
OPTIONAL MATCH (n)-[r]-(connected:base)
|
OPTIONAL MATCH (n)-[r]-(connected:base)
|
||||||
RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
|
RETURN id AS queried_id, n.entity_id AS node_entity_id,
|
||||||
|
connected.entity_id AS connected_entity_id,
|
||||||
|
startNode(r).entity_id AS start_entity_id
|
||||||
"""
|
"""
|
||||||
result = await session.run(query, node_ids=node_ids)
|
result = await session.run(query, node_ids=node_ids)
|
||||||
|
|
||||||
# Initialize the dictionary with empty lists for each node ID
|
# Initialize the dictionary with empty lists for each node ID
|
||||||
edges_dict = {node_id: [] for node_id in node_ids}
|
edges_dict = {node_id: [] for node_id in node_ids}
|
||||||
|
|
||||||
|
# Process results to include both outgoing and incoming edges
|
||||||
async for record in result:
|
async for record in result:
|
||||||
queried_id = record["queried_id"]
|
queried_id = record["queried_id"]
|
||||||
source_label = record["source_entity_id"]
|
node_entity_id = record["node_entity_id"]
|
||||||
target_label = record["target_entity_id"]
|
connected_entity_id = record["connected_entity_id"]
|
||||||
if source_label and target_label:
|
start_entity_id = record["start_entity_id"]
|
||||||
edges_dict[queried_id].append((source_label, target_label))
|
|
||||||
|
# Skip if either node is None
|
||||||
|
if not node_entity_id or not connected_entity_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine the actual direction of the edge
|
||||||
|
# If the start node is the queried node, it's an outgoing edge
|
||||||
|
# Otherwise, it's an incoming edge
|
||||||
|
if start_entity_id == node_entity_id:
|
||||||
|
# Outgoing edge: (queried_node -> connected_node)
|
||||||
|
edges_dict[queried_id].append((node_entity_id, connected_entity_id))
|
||||||
|
else:
|
||||||
|
# Incoming edge: (connected_node -> queried_node)
|
||||||
|
edges_dict[queried_id].append((connected_entity_id, node_entity_id))
|
||||||
|
|
||||||
await result.consume() # Ensure results are fully consumed
|
await result.consume() # Ensure results are fully consumed
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user