diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 3c7e57a7..9fdfb3bb 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage): max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 Returns: - KnowledgeGraph object containing nodes and edges + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit """ result = KnowledgeGraph() seen_nodes = set() @@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage): ) as session: try: if node_label == "*": + # First check total node count to determine if graph is truncated + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + + if count_record and count_record["total"] > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() + + # Run main query to get nodes with highest degree main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() @@ -683,14 +701,20 @@ class Neo4JStorage(BaseGraphStorage): RETURN filtered_nodes AS node_info, collect(DISTINCT r) AS relationships """ - result_set = await session.run( - main_query, - {"max_nodes": max_nodes}, - ) + result_set = None + try: + result_set = await session.run( + main_query, + {"max_nodes": max_nodes}, + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() else: - # Main query uses partial matching - main_query = """ + # First try without limit to check if we need to truncate + full_query = """ MATCH (start) WHERE start.entity_id = $entity_id WITH start @@ -698,63 +722,118 @@ class Neo4JStorage(BaseGraphStorage): relationshipFilter: '', minLevel: 0, maxLevel: $max_depth, - limit: $max_nodes, bfs: true }) YIELD nodes, relationships + WITH nodes, relationships, size(nodes) AS total_nodes UNWIND nodes AS node - WITH collect({node: node}) AS node_info, relationships - RETURN node_info, relationships + WITH collect({node: node}) AS node_info, relationships, total_nodes + RETURN node_info, relationships, total_nodes """ - result_set = await session.run( - main_query, - { - "entity_id": node_label, - "max_depth": max_depth, - "max_nodes": max_nodes, - }, - ) - try: - record = await result_set.single() - - if record: - # Handle nodes (compatible with multi-label cases) - for node_info in record["node_info"]: - node = node_info["node"] - node_id = node.id - if node_id not in seen_nodes: - result.nodes.append( - KnowledgeGraphNode( - id=f"{node_id}", - labels=[node.get("entity_id")], - properties=dict(node), - ) - ) - seen_nodes.add(node_id) - - # Handle relationships (including direction information) - for rel in record["relationships"]: - edge_id = rel.id - if edge_id not in seen_edges: - start = rel.start_node - end = rel.end_node - result.edges.append( - KnowledgeGraphEdge( - id=f"{edge_id}", - type=rel.type, - source=f"{start.id}", - target=f"{end.id}", - properties=dict(rel), - ) - ) - seen_edges.add(edge_id) - - logger.info( - f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges" + # Try to get full result + full_result = None + try: + full_result = await session.run( + full_query, + { + "entity_id": node_label, + "max_depth": max_depth, + }, ) - finally: - await result_set.consume() # Ensure result set is consumed + full_record = await full_result.single() + + # If no record found, return empty KnowledgeGraph + if not full_record: + logger.debug(f"No nodes found for entity_id: {node_label}") + return result + + # If record found, check node count + total_nodes = full_record["total_nodes"] + + if total_nodes <= max_nodes: + # If node count is within limit, use full result directly + logger.debug( + f"Using full result with {total_nodes} nodes (no truncation needed)" + ) + record = full_record + else: + # If node count exceeds limit, set truncated flag and run limited query + result.is_truncated = True + logger.info( + f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" + ) + + # Run limited query + limited_query = """ + MATCH (start) + WHERE start.entity_id = $entity_id + WITH start + CALL apoc.path.subgraphAll(start, { + relationshipFilter: '', + minLevel: 0, + maxLevel: $max_depth, + limit: $max_nodes, + bfs: true + }) + YIELD nodes, relationships + UNWIND nodes AS node + WITH collect({node: node}) AS node_info, relationships + RETURN node_info, relationships + """ + result_set = None + try: + result_set = await session.run( + limited_query, + { + "entity_id": node_label, + "max_depth": max_depth, + "max_nodes": max_nodes, + }, + ) + record = await result_set.single() + finally: + if result_set: + await result_set.consume() + finally: + if full_result: + await full_result.consume() + + if record: + # Handle nodes (compatible with multi-label cases) + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + seen_nodes.add(node_id) + + # Handle relationships (including direction information) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges" + ) except neo4jExceptions.ClientError as e: logger.warning(f"APOC plugin error: {str(e)}") @@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage): "Neo4j: falling back to basic Cypher recursive search..." ) return await self._robust_fallback(node_label, max_depth, max_nodes) + else: + logger.warning( + "Neo4j: APOC plugin error with wildcard query, returning empty result" + ) return result @@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage): logger.debug(f"Reached max depth: {max_depth}") return if len(visited_nodes) >= max_nodes: - logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") + # Set truncated flag when we hit the max_nodes limit + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to: {max_nodes} nodes" + ) return # Check if node already visited