diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index fe01aaf3..39b1bd57 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -382,26 +382,35 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: try: - # Critical debug step: first verify if starting node exists - validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" - validate_result = await session.run(validate_query) - if not await validate_result.single(): - logger.warning(f"Starting node {label} does not exist!") - return result + main_query = "" + if label == '*': + main_query = """ + MATCH (n) + WITH collect(DISTINCT n) AS nodes + MATCH ()-[r]-() + RETURN nodes, collect(DISTINCT r) AS relationships; + """ + else: + # Critical debug step: first verify if starting node exists + validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" + validate_result = await session.run(validate_query) + if not await validate_result.single(): + logger.warning(f"Starting node {label} does not exist!") + return result - # Optimized query (including direction handling and self-loops) - main_query = f""" - MATCH (start:`{label}`) - WITH start - CALL apoc.path.subgraphAll(start, {{ - relationshipFilter: '>', - minLevel: 0, - maxLevel: {max_depth}, - bfs: true - }}) - YIELD nodes, relationships - RETURN nodes, relationships - """ + # Optimized query (including direction handling and self-loops) + main_query = f""" + MATCH (start:`{label}`) + WITH start + CALL apoc.path.subgraphAll(start, {{ + relationshipFilter: '>', + minLevel: 0, + maxLevel: {max_depth}, + bfs: true + }}) + YIELD nodes, relationships + RETURN nodes, relationships + """ result_set = await session.run(main_query) record = await result_set.single() @@ -409,28 +418,29 @@ class Neo4JStorage(BaseGraphStorage): # Handle nodes (compatible with multi-label cases) for node in record["nodes"]: # Use node ID + label combination as unique identifier - node_id = f"{node.id}_{'_'.join(node.labels)}" + node_id = node.id if node_id not in seen_nodes: - node_data = dict(node) + node_data = {} node_data["labels"] = list(node.labels) # Keep all labels + node_data["id"] = f"{node_id}" + node_data["properties"] = dict(node) result["nodes"].append(node_data) seen_nodes.add(node_id) # Handle relationships (including direction information) for rel in record["relationships"]: - edge_id = f"{rel.id}_{rel.type}" + edge_id = rel.id if edge_id not in seen_edges: start = rel.start_node end = rel.end_node - edge_data = dict(rel) + edge_data = {} edge_data.update( { - "source": f"{start.id}_{'_'.join(start.labels)}", - "target": f"{end.id}_{'_'.join(end.labels)}", + "source": f"{start.id}", + "target": f"{end.id}", "type": rel.type, - "direction": rel.element_id.split( - "->" if rel.end_node == end else "<-" - )[1], + "id": f"{edge_id}", + "properties": dict(rel), } ) result["edges"].append(edge_data)