Add support for querying all nodes and relationships in Neo4j via the wildcard label '*'

This commit is contained in:
ArnoChen
2025-02-09 22:22:59 +08:00
parent c1d7fbe02b
commit 8d23ed16be

View File

@@ -382,26 +382,35 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
try: try:
# Critical debug step: first verify if starting node exists main_query = ""
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" if label == '*':
validate_result = await session.run(validate_query) main_query = """
if not await validate_result.single(): MATCH (n)
logger.warning(f"Starting node {label} does not exist!") WITH collect(DISTINCT n) AS nodes
return result MATCH ()-[r]-()
RETURN nodes, collect(DISTINCT r) AS relationships;
"""
else:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
return result
# Optimized query (including direction handling and self-loops) # Optimized query (including direction handling and self-loops)
main_query = f""" main_query = f"""
MATCH (start:`{label}`) MATCH (start:`{label}`)
WITH start WITH start
CALL apoc.path.subgraphAll(start, {{ CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>', relationshipFilter: '>',
minLevel: 0, minLevel: 0,
maxLevel: {max_depth}, maxLevel: {max_depth},
bfs: true bfs: true
}}) }})
YIELD nodes, relationships YIELD nodes, relationships
RETURN nodes, relationships RETURN nodes, relationships
""" """
result_set = await session.run(main_query) result_set = await session.run(main_query)
record = await result_set.single() record = await result_set.single()
@@ -409,28 +418,29 @@ class Neo4JStorage(BaseGraphStorage):
# Handle nodes (compatible with multi-label cases) # Handle nodes (compatible with multi-label cases)
for node in record["nodes"]: for node in record["nodes"]:
# Use node ID + label combination as unique identifier # Use node ID + label combination as unique identifier
node_id = f"{node.id}_{'_'.join(node.labels)}" node_id = node.id
if node_id not in seen_nodes: if node_id not in seen_nodes:
node_data = dict(node) node_data = {}
node_data["labels"] = list(node.labels) # Keep all labels node_data["labels"] = list(node.labels) # Keep all labels
node_data["id"] = f"{node_id}"
node_data["properties"] = dict(node)
result["nodes"].append(node_data) result["nodes"].append(node_data)
seen_nodes.add(node_id) seen_nodes.add(node_id)
# Handle relationships (including direction information) # Handle relationships (including direction information)
for rel in record["relationships"]: for rel in record["relationships"]:
edge_id = f"{rel.id}_{rel.type}" edge_id = rel.id
if edge_id not in seen_edges: if edge_id not in seen_edges:
start = rel.start_node start = rel.start_node
end = rel.end_node end = rel.end_node
edge_data = dict(rel) edge_data = {}
edge_data.update( edge_data.update(
{ {
"source": f"{start.id}_{'_'.join(start.labels)}", "source": f"{start.id}",
"target": f"{end.id}_{'_'.join(end.labels)}", "target": f"{end.id}",
"type": rel.type, "type": rel.type,
"direction": rel.element_id.split( "id": f"{edge_id}",
"->" if rel.end_node == end else "<-" "properties": dict(rel),
)[1],
} }
) )
result["edges"].append(edge_data) result["edges"].append(edge_data)