add support for querying all nodes and relationships in Neo4j

This commit is contained in:
ArnoChen
2025-02-09 22:22:59 +08:00
parent c1d7fbe02b
commit 8d23ed16be

View File

@@ -382,6 +382,15 @@ class Neo4JStorage(BaseGraphStorage):
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
try: try:
main_query = ""
if label == '*':
main_query = """
MATCH (n)
WITH collect(DISTINCT n) AS nodes
MATCH ()-[r]-()
RETURN nodes, collect(DISTINCT r) AS relationships;
"""
else:
# Critical debug step: first verify if starting node exists # Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query) validate_result = await session.run(validate_query)
@@ -409,28 +418,29 @@ class Neo4JStorage(BaseGraphStorage):
# Handle nodes (compatible with multi-label cases) # Handle nodes (compatible with multi-label cases)
for node in record["nodes"]: for node in record["nodes"]:
# Use node ID + label combination as unique identifier # Use node ID + label combination as unique identifier
node_id = f"{node.id}_{'_'.join(node.labels)}" node_id = node.id
if node_id not in seen_nodes: if node_id not in seen_nodes:
node_data = dict(node) node_data = {}
node_data["labels"] = list(node.labels) # Keep all labels node_data["labels"] = list(node.labels) # Keep all labels
node_data["id"] = f"{node_id}"
node_data["properties"] = dict(node)
result["nodes"].append(node_data) result["nodes"].append(node_data)
seen_nodes.add(node_id) seen_nodes.add(node_id)
# Handle relationships (including direction information) # Handle relationships (including direction information)
for rel in record["relationships"]: for rel in record["relationships"]:
edge_id = f"{rel.id}_{rel.type}" edge_id = rel.id
if edge_id not in seen_edges: if edge_id not in seen_edges:
start = rel.start_node start = rel.start_node
end = rel.end_node end = rel.end_node
edge_data = dict(rel) edge_data = {}
edge_data.update( edge_data.update(
{ {
"source": f"{start.id}_{'_'.join(start.labels)}", "source": f"{start.id}",
"target": f"{end.id}_{'_'.join(end.labels)}", "target": f"{end.id}",
"type": rel.type, "type": rel.type,
"direction": rel.element_id.split( "id": f"{edge_id}",
"->" if rel.end_node == end else "<-" "properties": dict(rel),
)[1],
} }
) )
result["edges"].append(edge_data) result["edges"].append(edge_data)