From b0035374298f21ef4109f0533c3a9083687c5d80 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 4 Apr 2025 02:05:29 +0800
Subject: [PATCH] Fix APOC fallback error

---
 lightrag/kg/neo4j_impl.py | 217 ++++++++++++++++++++++----------------
 1 file changed, 124 insertions(+), 93 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index ba53a349..b2c54d82 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -2,7 +2,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, final, Optional
+from typing import Any, final

 import numpy as np
 import configparser
@@ -713,6 +713,7 @@ class Neo4JStorage(BaseGraphStorage):
                     await result_set.consume()

             else:
+                # return await self._robust_fallback(node_label, max_depth, max_nodes)
                 # First try without limit to check if we need to truncate
                 full_query = """
                 MATCH (start)
@@ -855,98 +856,14 @@ class Neo4JStorage(BaseGraphStorage):
         """
         Fallback implementation when APOC plugin is not available or incompatible.
         This method implements the same functionality as get_knowledge_graph but uses
-        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
         """
+        from collections import deque
+
         result = KnowledgeGraph()
         visited_nodes = set()
         visited_edges = set()
-
-        async def traverse(
-            node: KnowledgeGraphNode,
-            edge: Optional[KnowledgeGraphEdge],
-            current_depth: int,
-        ):
-            # Check traversal limits
-            if current_depth > max_depth:
-                logger.debug(f"Reached max depth: {max_depth}")
-                return
-            if len(visited_nodes) >= max_nodes:
-                # Set truncated flag when we hit the max_nodes limit
-                result.is_truncated = True
-                logger.info(
-                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
-                )
-                return
-
-            # Check if node already visited
-            if node.id in visited_nodes:
-                return
-
-            # Get all edges and target nodes
-            async with self._driver.session(
-                database=self._DATABASE, default_access_mode="READ"
-            ) as session:
-                query = """
-                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
-                WITH r, b, id(r) as edge_id, id(b) as target_id
-                RETURN r, b, edge_id, target_id
-                """
-                results = await session.run(query, entity_id=node.id)
-
-                # Get all records and release database connection
-                records = await results.fetch(
-                    1000
-                )  # Max neighbour nodes we can handled
-                await results.consume()  # Ensure results are consumed
-
-                # Nodes not connected to start node need to check degree
-                # if current_depth > 1 and len(records) < min_degree:
-                #     return
-
-                # Add current node to result
-                result.nodes.append(node)
-                visited_nodes.add(node.id)
-
-                # Add edge to result if it exists and not already added
-                if edge and edge.id not in visited_edges:
-                    result.edges.append(edge)
-                    visited_edges.add(edge.id)
-
-                # Prepare nodes and edges for recursive processing
-                nodes_to_process = []
-                for record in records:
-                    rel = record["r"]
-                    edge_id = str(record["edge_id"])
-                    if edge_id not in visited_edges:
-                        b_node = record["b"]
-                        target_id = b_node.get("entity_id")
-
-                        if target_id:  # Only process if target node has entity_id
-                            # Create KnowledgeGraphNode for target
-                            target_node = KnowledgeGraphNode(
-                                id=f"{target_id}",
-                                labels=list(f"{target_id}"),
-                                properties=dict(b_node.properties),
-                            )
-
-                            # Create KnowledgeGraphEdge
-                            target_edge = KnowledgeGraphEdge(
-                                id=f"{edge_id}",
-                                type=rel.type,
-                                source=f"{node.id}",
-                                target=f"{target_id}",
-                                properties=dict(rel),
-                            )
-
-                            nodes_to_process.append((target_node, target_edge))
-                        else:
-                            logger.warning(
-                                f"Skipping edge {edge_id} due to missing labels on target node"
-                            )
-
-            # Process nodes after releasing database connection
-            for target_node, target_edge in nodes_to_process:
-                await traverse(target_node, target_edge, current_depth + 1)
+        visited_edge_pairs = set()  # Tracks processed edge pairs as sorted (source_id, target_id) tuples

         # Get the starting node's data
         async with self._driver.session(
@@ -965,15 +882,129 @@ class Neo4JStorage(BaseGraphStorage):
             # Create initial KnowledgeGraphNode
             start_node = KnowledgeGraphNode(
                 id=f"{node_record['n'].get('entity_id')}",
-                labels=list(f"{node_record['n'].get('entity_id')}"),
-                properties=dict(node_record["n"].properties),
+                labels=[node_record["n"].get("entity_id")],
+                properties=dict(node_record["n"]._properties),
             )
         finally:
             await node_result.consume()  # Ensure results are consumed

-        # Start traversal with the initial node
-        await traverse(start_node, None, 0)
+        # Initialize queue for BFS with (node, edge, depth) tuples;
+        # edge is None for the starting node
+        queue = deque([(start_node, None, 0)])
+
+        # True BFS implementation using a queue
+        while queue and len(visited_nodes) < max_nodes:
+            # Dequeue the next node to process
+            current_node, current_edge, current_depth = queue.popleft()
+
+            # Skip if already visited or exceeds max depth
+            if current_node.id in visited_nodes:
+                continue
+
+            if current_depth > max_depth:
+                logger.debug(
+                    f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
+                )
+                continue
+
+            # Add current node to result
+            result.nodes.append(current_node)
+            visited_nodes.add(current_node.id)
+
+            # Add edge to result if it exists and not already added
+            if current_edge and current_edge.id not in visited_edges:
+                result.edges.append(current_edge)
+                visited_edges.add(current_edge.id)
+
+            # Stop if we've reached the node limit
+            if len(visited_nodes) >= max_nodes:
+                result.is_truncated = True
+                logger.info(
+                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
+                )
+                break
+
+            # Get all edges and target nodes for the current node (even at max_depth)
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                query = """
+                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
+                WITH r, b, id(r) as edge_id, id(b) as target_id
+                RETURN r, b, edge_id, target_id
+                """
+                results = await session.run(query, entity_id=current_node.id)
+
+                # Get all records and release database connection
+                records = await results.fetch(1000)  # Max neighbor nodes we can handle
+                await results.consume()  # Ensure results are consumed
+
+            # Process all neighbors - capture all edges but only queue unvisited nodes
+            for record in records:
+                rel = record["r"]
+                edge_id = str(record["edge_id"])
+
+                if edge_id not in visited_edges:
+                    b_node = record["b"]
+                    target_id = b_node.get("entity_id")
+
+                    if target_id:  # Only process if target node has entity_id
+                        # Create KnowledgeGraphNode for target
+                        target_node = KnowledgeGraphNode(
+                            id=f"{target_id}",
+                            labels=[target_id],
+                            properties=dict(b_node._properties),
+                        )
+
+                        # Create KnowledgeGraphEdge
+                        target_edge = KnowledgeGraphEdge(
+                            id=f"{edge_id}",
+                            type=rel.type,
+                            source=f"{current_node.id}",
+                            target=f"{target_id}",
+                            properties=dict(rel),
+                        )
+
+                        # Sort source_id and target_id so (A, B) and (B, A) count as the same edge
+                        sorted_pair = tuple(sorted([current_node.id, target_id]))
+
+                        # Check whether an equivalent edge already exists (treating edges as undirected)
+                        if sorted_pair not in visited_edge_pairs:
+                            # Only add the edge if the target node is already in the result or will be added to it
+                            if target_id in visited_nodes or (
+                                target_id not in visited_nodes
+                                and current_depth < max_depth
+                            ):
+                                result.edges.append(target_edge)
+                                visited_edges.add(edge_id)
+                                visited_edge_pairs.add(sorted_pair)
+
+                        # Only add unvisited nodes to the queue for further expansion
+                        if target_id not in visited_nodes:
+                            # Only add to queue if we're not at max depth yet
+                            if current_depth < max_depth:
+                                # Add node to queue with incremented depth
+                                # Edge is already added to result, so we pass None as edge
+                                queue.append((target_node, None, current_depth + 1))
+                            else:
+                                # At max depth, we've already added the edge but we don't add the node
+                                # This prevents adding nodes beyond max_depth to the result
+                                logger.debug(
+                                    f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
+                                )
+                        else:
+                            # If target node already exists in result, we don't need to add it again
+                            logger.debug(
+                                f"Node {target_id} already visited, edge added but node not queued"
+                            )
+                    else:
+                        logger.warning(
+                            f"Skipping edge {edge_id} due to missing entity_id on target node"
+                        )
+
+        logger.info(
+            f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+        )
         return result

     async def get_all_labels(self) -> list[str]:
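
--
Reviewer note (commentary, not part of the patch): the heart of the change is
swapping the recursive traversal for a true breadth-first search, with
undirected edges deduplicated via sorted endpoint pairs. Below is a minimal,
self-contained sketch of that technique against an in-memory adjacency list,
runnable without a Neo4j instance. The toy graph, function name, and caps are
invented; the neighbor-filtering condition is a logically equivalent
simplification of the patch's "target_id in visited_nodes or (target_id not
in visited_nodes and current_depth < max_depth)".

    from collections import deque

    # Toy undirected graph; node names are invented for illustration
    GRAPH = {
        "A": ["B", "C"],
        "B": ["A", "C", "D"],
        "C": ["A", "B"],
        "D": ["B", "E"],
        "E": ["D"],
    }

    def bfs_subgraph(start, max_depth, max_nodes):
        nodes, edges = [], []
        visited = set()
        visited_pairs = set()  # sorted pairs: (A, B) and (B, A) are one edge
        truncated = False
        queue = deque([(start, 0)])  # (node, depth)

        while queue and len(visited) < max_nodes:
            node, depth = queue.popleft()
            if node in visited or depth > max_depth:
                continue
            nodes.append(node)
            visited.add(node)
            if len(visited) >= max_nodes:
                truncated = True  # hit the node cap; result is partial
                break
            for neighbor in GRAPH.get(node, []):
                pair = tuple(sorted((node, neighbor)))
                # Keep each undirected edge once, and only when both
                # endpoints are expected to make it into the result
                if pair not in visited_pairs and (
                    neighbor in visited or depth < max_depth
                ):
                    edges.append(pair)
                    visited_pairs.add(pair)
                # Queue unvisited neighbors only while below the depth cap
                if neighbor not in visited and depth < max_depth:
                    queue.append((neighbor, depth + 1))
        return nodes, edges, truncated

    print(bfs_subgraph("A", max_depth=1, max_nodes=10))
    # -> (['A', 'B', 'C'], [('A', 'B'), ('A', 'C'), ('B', 'C')], False)

Note how the (B, C) edge is captured even though both endpoints sit at
max_depth: the visit to B skips it (C not yet visited, depth cap reached),
but the later visit to C adds it once B is in the visited set.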
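The patch also preserves the old code's fetch-then-consume access pattern:
neighbor records are buffered client-side with fetch(1000) and the result is
consumed before any per-record work, so the connection returns to the pool
early. Here is a standalone sketch of that pattern with the neo4j 5.x async
Python driver; the URI, credentials, and entity id are placeholders, not
values from the patch:

    import asyncio

    from neo4j import AsyncGraphDatabase

    async def fetch_neighbors(entity_id):
        # Placeholder connection details, not taken from the patch
        driver = AsyncGraphDatabase.driver(
            "neo4j://localhost:7687", auth=("neo4j", "password")
        )
        async with driver.session(default_access_mode="READ") as session:
            result = await session.run(
                "MATCH (a:base {entity_id: $entity_id})-[r]-(b) RETURN r, b",
                entity_id=entity_id,
            )
            records = await result.fetch(1000)  # buffer up to 1000 records
            await result.consume()  # discard the rest; frees the connection
        await driver.close()
        return records  # safe to process after the session is closed

    asyncio.run(fetch_neighbors("some-entity"))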