diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index ea316d0f..60e8982e 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,7 +3,7 @@ import inspect import os import re from dataclasses import dataclass -from typing import Any, List, Dict, final +from typing import Any, final, Optional import numpy as np import configparser @@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage): ) return degrees - async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: @@ -321,60 +320,59 @@ class Neo4JStorage(BaseGraphStorage): """ result = await session.run(query) - try: - records = await result.fetch(2) # Get up to 2 records to check for duplicates - if len(records) > 1: - logger.warning( - f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." - ) - if records: - try: - result = dict(records[0]["edge_properties"]) - logger.debug(f"Result: {result}") - # Ensure required keys exist with defaults - required_keys = { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } - for key, default_value in required_keys.items(): - if key not in result: - result[key] = default_value - logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " - f"missing {key}, using default: {default_value}" - ) - - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" - ) - return result - except (KeyError, TypeError, ValueError) as e: - logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" - ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - - logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + records = await result.fetch(2) # Get up to 2 records to check for duplicates + await result.consume() # Ensure result is fully consumed before processing records + + if len(records) > 1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - finally: - await result.consume() # Ensure result is fully consumed + if records: + try: + edge_result = dict(records[0]["edge_properties"]) + logger.debug(f"Result: {edge_result}") + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}" + ) + return edge_result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + ) + # Return default edge properties when no edge found + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } except Exception as e: logger.error( @@ -685,30 +683,36 @@ class Neo4JStorage(BaseGraphStorage): await result_set.consume() # Ensure result set is consumed except neo4jExceptions.ClientError as e: - logger.warning( - f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation" - ) - if inclusive: + logger.warning(f"APOC plugin error: {str(e)}") + if label != "*": logger.warning( - "Inclusive search mode is not supported in recursive query, using exact matching" + "Neo4j: falling back to basic Cypher recursive search..." ) - return await self._robust_fallback(label, max_depth, min_degree) + if inclusive: + logger.warning( + "Neo4j: inclusive search mode is not supported in recursive query, using exact matching" + ) + return await self._robust_fallback(label, max_depth, min_degree) return result async def _robust_fallback( self, label: str, max_depth: int, min_degree: int = 0 - ) -> Dict[str, List[Dict]]: + ) -> KnowledgeGraph: """ Fallback implementation when APOC plugin is not available or incompatible. This method implements the same functionality as get_knowledge_graph but uses only basic Cypher queries and recursive traversal instead of APOC procedures. """ - result = {"nodes": [], "edges": []} + result = KnowledgeGraph() visited_nodes = set() visited_edges = set() - async def traverse(current_label: str, current_depth: int): + async def traverse( + node: KnowledgeGraphNode, + edge: Optional[KnowledgeGraphEdge], + current_depth: int, + ): # Check traversal limits if current_depth > max_depth: logger.debug(f"Reached max depth: {max_depth}") @@ -717,62 +721,101 @@ class Neo4JStorage(BaseGraphStorage): logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") return - # Get current node details - node = await self.get_node(current_label) - if not node: + # Check if node already visited + if node.id in visited_nodes: return - node_id = f"{current_label}" - if node_id in visited_nodes: - return - visited_nodes.add(node_id) - - # Add node data with label as ID - result["nodes"].append( - {"id": current_label, "labels": current_label, "properties": node} - ) - - # Get connected nodes that meet the degree requirement - # Note: We don't need to check a's degree since it's the current node - # and was already validated in the previous iteration - query = f""" - MATCH (a:`{current_label}`)-[r]-(b) - WITH r, b, - COUNT((b)--()) AS b_degree - WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) - RETURN r, b - """ + # Get all edges and target nodes async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - results = await session.run(query, {"min_degree": min_degree}) - async for record in results: - # Handle edges + query = """ + MATCH (a)-[r]-(b) + WHERE id(a) = toInteger($node_id) + WITH r, b, id(r) as edge_id, id(b) as target_id + RETURN r, b, edge_id, target_id + """ + results = await session.run(query, {"node_id": node.id}) + + # Get all records and release database connection + records = await results.fetch() + await results.consume() # Ensure results are consumed + + # Nodes not connected to start node need to check degree + if current_depth > 1 and len(records) < min_degree: + return + + # Add current node to result + result.nodes.append(node) + visited_nodes.add(node.id) + + # Add edge to result if it exists and not already added + if edge and edge.id not in visited_edges: + result.edges.append(edge) + visited_edges.add(edge.id) + + # Prepare nodes and edges for recursive processing + nodes_to_process = [] + for record in records: rel = record["r"] - edge_id = f"{rel.id}_{rel.type}" + edge_id = str(record["edge_id"]) if edge_id not in visited_edges: b_node = record["b"] - if b_node.labels: # Only process if target node has labels - target_label = list(b_node.labels)[0] - result["edges"].append( - { - "id": f"{current_label}_{target_label}", - "type": rel.type, - "source": current_label, - "target": target_label, - "properties": dict(rel), - } - ) - visited_edges.add(edge_id) + target_id = str(record["target_id"]) - # Continue traversal - await traverse(target_label, current_depth + 1) + if b_node.labels: # Only process if target node has labels + # Create KnowledgeGraphNode for target + target_node = KnowledgeGraphNode( + id=target_id, + labels=list(b_node.labels), + properties=dict(b_node), + ) + + # Create KnowledgeGraphEdge + target_edge = KnowledgeGraphEdge( + id=edge_id, + type=rel.type, + source=node.id, + target=target_id, + properties=dict(rel), + ) + + nodes_to_process.append((target_node, target_edge)) else: logger.warning( f"Skipping edge {edge_id} due to missing labels on target node" ) - await traverse(label, 0) + # Process nodes after releasing database connection + for target_node, target_edge in nodes_to_process: + await traverse(target_node, target_edge, current_depth + 1) + + # Get the starting node's data + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (n:`{label}`) + RETURN id(n) as node_id, n + """ + node_result = await session.run(query) + try: + node_record = await node_result.single() + if not node_record: + return result + + # Create initial KnowledgeGraphNode + start_node = KnowledgeGraphNode( + id=str(node_record["node_id"]), + labels=list(node_record["n"].labels), + properties=dict(node_record["n"]), + ) + finally: + await node_result.consume() # Ensure results are consumed + + # Start traversal with the initial node + await traverse(start_node, None, 0) + return result async def get_all_labels(self) -> list[str]: