diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 8d5a1a55..18bd6859 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -176,23 +176,6 @@ class Neo4JStorage(BaseGraphStorage): # Noe4J handles persistence automatically pass - def _ensure_label(self, label: str) -> str: - """Ensure a label is valid - - Args: - label: The label to validate - - Returns: - str: The cleaned label - - Raises: - ValueError: If label is empty after cleaning - """ - clean_label = label.strip('"') - if not clean_label: - raise ValueError("Neo4j: Label cannot be empty") - return clean_label - async def has_node(self, node_id: str) -> bool: """ Check if a node with the given label exists in the database @@ -207,19 +190,18 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ - entity_name_label = self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" - result = await session.run(query) + query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id = node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed return single_result["node_exists"] except Exception as e: logger.error( - f"Error checking node existence for {entity_name_label}: {str(e)}" + f"Error checking node existence for {node_id}: {str(e)}" ) await result.consume() # Ensure results are consumed even on error raise @@ -239,24 +221,21 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If either node_id is invalid Exception: If there is an error executing the query """ - entity_name_label_source = self._ensure_label(source_node_id) - entity_name_label_target = self._ensure_label(target_node_id) - async with self._driver.session( 
database=self._DATABASE, default_access_mode="READ" ) as session: try: query = ( - f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " + "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run(query) + result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed return single_result["edgeExists"] except Exception as e: logger.error( - f"Error checking edge existence between {entity_name_label_source} and {entity_name_label_target}: {str(e)}" + f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" ) await result.consume() # Ensure results are consumed even on error raise @@ -275,13 +254,12 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ - entity_name_label = self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n" - result = await session.run(query, entity_id=entity_name_label) + query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + result = await session.run(query, entity_id=node_id) try: records = await result.fetch( 2 @@ -289,20 +267,21 @@ class Neo4JStorage(BaseGraphStorage): if len(records) > 1: logger.warning( - f"Multiple nodes found with label '{entity_name_label}'. Using first node." + f"Multiple nodes found with label '{node_id}'. Using first node." 
) if records: node = records[0]["n"] node_dict = dict(node) - logger.debug( - f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" - ) + # Remove base label from labels list if it exists + if "labels" in node_dict: + node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + logger.debug(f"Neo4j query node {query} return: {node_dict}") return node_dict return None finally: await result.consume() # Ensure result is fully consumed except Exception as e: - logger.error(f"Error getting node for {entity_name_label}: {str(e)}") + logger.error(f"Error getting node for {node_id}: {str(e)}") raise async def node_degree(self, node_id: str) -> int: @@ -320,42 +299,33 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ - entity_name_label = self._ensure_label(node_id) - async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = f""" - MATCH (n:`{entity_name_label}`) + query = """ + MATCH (n:base {entity_id: $entity_id}) OPTIONAL MATCH (n)-[r]-() - RETURN n, COUNT(r) AS degree + RETURN COUNT(r) AS degree """ - result = await session.run(query) + result = await session.run(query, entity_id = node_id) try: - records = await result.fetch(100) + record = await result.single() - if not records: + if not record: logger.warning( - f"No node found with label '{entity_name_label}'" + f"No node found with label '{node_id}'" ) return 0 - if len(records) > 1: - logger.warning( - f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" - ) - - degree = records[0]["degree"] - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" - ) + degree = record["degree"] + logger.debug(f"Neo4j query node degree for {node_id} return: {degree}") return degree finally: await result.consume() # Ensure result is fully consumed except Exception as 
e: logger.error( - f"Error getting node degree for {entity_name_label}: {str(e)}" + f"Error getting node degree for {node_id}: {str(e)}" ) raise @@ -369,11 +339,8 @@ class Neo4JStorage(BaseGraphStorage): Returns: int: Sum of the degrees of both nodes """ - entity_name_label_source = self._ensure_label(src_id) - entity_name_label_target = self._ensure_label(tgt_id) - - src_degree = await self.node_degree(entity_name_label_source) - trg_degree = await self.node_degree(entity_name_label_target) + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) # Convert None to 0 for addition src_degree = 0 if src_degree is None else src_degree @@ -399,24 +366,20 @@ class Neo4JStorage(BaseGraphStorage): Exception: If there is an error executing the query """ try: - entity_name_label_source = self._ensure_label(source_node_id) - entity_name_label_target = self._ensure_label(target_node_id) - async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = f""" - MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`) + query = """ + MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) RETURN properties(r) as edge_properties """ - - result = await session.run(query) + result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) try: records = await result.fetch(2) if len(records) > 1: logger.warning( - f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge." 
) if records: try: @@ -433,7 +396,7 @@ class Neo4JStorage(BaseGraphStorage): if key not in edge_result: edge_result[key] = default_value logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"Edge between {source_node_id} and {target_node_id} " f"missing {key}, using default: {default_value}" ) @@ -443,8 +406,8 @@ class Neo4JStorage(BaseGraphStorage): return edge_result except (KeyError, TypeError, ValueError) as e: logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" + f"Error processing edge properties between {source_node_id} " + f"and {target_node_id}: {str(e)}" ) # Return default edge properties on error return { @@ -455,7 +418,7 @@ class Neo4JStorage(BaseGraphStorage): } logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}" ) # Return default edge properties when no edge found return { @@ -488,30 +451,30 @@ class Neo4JStorage(BaseGraphStorage): Exception: If there is an error executing the query """ try: - node_label = self._ensure_label(source_node_id) - - query = f"""MATCH (n:`{node_label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected""" - async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - results = await session.run(query) - edges = [] + query = """MATCH (n:base {entity_id: $entity_id}) + OPTIONAL MATCH (n)-[r]-(connected:base) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + results = await session.run(query, entity_id=source_node_id) + edges = [] async for record in results: source_node = record["n"] connected_node = record["connected"] + # Skip if either node is None + if not source_node or not connected_node: + continue + source_label = ( - 
list(source_node.labels)[0] if source_node.labels else None + source_node.get("entity_id") if source_node.get("entity_id") else None ) target_label = ( - list(connected_node.labels)[0] - if connected_node and connected_node.labels - else None + connected_node.get("entity_id") if connected_node.get("entity_id") else None ) if source_label and target_label: @@ -520,7 +483,7 @@ class Neo4JStorage(BaseGraphStorage): await results.consume() # Ensure results are consumed return edges except Exception as e: - logger.error(f"Error getting edges for node {node_label}: {str(e)}") + logger.error(f"Error getting edges for node {source_node_id}: {str(e)}") await results.consume() # Ensure results are consumed even on error raise except Exception as e: @@ -547,8 +510,9 @@ class Neo4JStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = self._ensure_label(node_id) properties = node_data + entity_type = properties["entity_type"] + entity_id = properties["entity_id"] if "entity_id" not in properties: raise ValueError("Neo4j: node properties must contain an 'entity_id' field") @@ -556,13 +520,14 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = f""" - MERGE (n:`{label}` {{entity_id: $properties.entity_id}}) + query = """ + MERGE (n:base {entity_id: $properties.entity_id}) SET n += $properties - """ + SET n:`%s` + """ % entity_type result = await tx.run(query, properties=properties) logger.debug( - f"Upserted node with label '{label}' and properties: {properties}" + f"Upserted node with entity_id '{entity_id}' and properties: {properties}" ) await result.consume() # Ensure result is fully consumed @@ -583,51 +548,6 @@ class Neo4JStorage(BaseGraphStorage): ) ), ) - async def _get_unique_node_entity_id(self, node_label: str) -> str: - """ - Get the entity_id of a node with the 
given label, ensuring the node is unique. - - Args: - node_label (str): Label of the node to check - - Returns: - str: The entity_id of the unique node - - Raises: - ValueError: If no node with the given label exists or if multiple nodes have the same label - """ - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - query = f""" - MATCH (n:`{node_label}`) - RETURN n, count(n) as node_count - """ - result = await session.run(query) - try: - records = await result.fetch( - 2 - ) # We only need to know if there are 0, 1, or >1 nodes - - if not records or records[0]["node_count"] == 0: - raise ValueError( - f"Neo4j: node with label '{node_label}' does not exist" - ) - - if records[0]["node_count"] > 1: - raise ValueError( - f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node" - ) - - node = records[0]["n"] - if "entity_id" not in node: - raise ValueError( - f"Neo4j: node with label '{node_label}' does not have an entity_id property" - ) - - return node["entity_id"] - finally: - await result.consume() # Ensure result is fully consumed @retry( stop=stop_after_attempt(3), @@ -657,38 +577,30 @@ class Neo4JStorage(BaseGraphStorage): Raises: ValueError: If either source or target node does not exist or is not unique """ - source_label = self._ensure_label(source_node_id) - target_label = self._ensure_label(target_node_id) - edge_properties = edge_data - - # Get entity_ids for source and target nodes, ensuring they are unique - source_entity_id = await self._get_unique_node_entity_id(source_label) - target_entity_id = await self._get_unique_node_entity_id(target_label) - try: + edge_properties = edge_data async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = f""" - MATCH (source:`{source_label}` {{entity_id: $source_entity_id}}) + query = """ + MATCH (source:base {entity_id: $source_entity_id}) WITH source - 
MATCH (target:`{target_label}` {{entity_id: $target_entity_id}}) + MATCH (target:base {entity_id: $target_entity_id}) MERGE (source)-[r:DIRECTED]-(target) SET r += $properties RETURN r, source, target """ result = await tx.run( query, - source_entity_id=source_entity_id, - target_entity_id=target_entity_id, + source_entity_id=source_node_id, + target_entity_id=target_node_id, properties=edge_properties, ) try: - records = await result.fetch(100) + records = await result.fetch(2) if records: logger.debug( - f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) " - f"to '{target_label}' (entity_id: {target_entity_id}) " + f"Upserted edge from '{source_node_id}' to '{target_node_id}'" f"with properties: {edge_properties}" ) finally: @@ -726,7 +638,6 @@ class Neo4JStorage(BaseGraphStorage): Returns: KnowledgeGraph: Complete connected subgraph for specified node """ - label = node_label.strip('"') result = KnowledgeGraph() seen_nodes = set() seen_edges = set() @@ -735,7 +646,7 @@ class Neo4JStorage(BaseGraphStorage): database=self._DATABASE, default_access_mode="READ" ) as session: try: - if label == "*": + if node_label == "*": main_query = """ MATCH (n) OPTIONAL MATCH (n)-[r]-() @@ -760,12 +671,11 @@ class Neo4JStorage(BaseGraphStorage): # Main query uses partial matching main_query = """ MATCH (start) - WHERE any(label IN labels(start) WHERE + WHERE CASE - WHEN $inclusive THEN label CONTAINS $label - ELSE label = $label + WHEN $inclusive THEN start.entity_id CONTAINS $entity_id + ELSE start.entity_id = $entity_id END - ) WITH start CALL apoc.path.subgraphAll(start, { relationshipFilter: '', @@ -799,7 +709,7 @@ class Neo4JStorage(BaseGraphStorage): main_query, { "max_nodes": MAX_GRAPH_NODES, - "label": label, + "entity_id": node_label, "inclusive": inclusive, "max_depth": max_depth, "min_degree": min_degree, @@ -818,7 +728,7 @@ class Neo4JStorage(BaseGraphStorage): result.nodes.append( KnowledgeGraphNode( id=f"{node_id}", - labels=list(node.labels), + 
labels=[label for label in node.labels if label != "base"], properties=dict(node), ) ) @@ -849,7 +759,7 @@ class Neo4JStorage(BaseGraphStorage): except neo4jExceptions.ClientError as e: logger.warning(f"APOC plugin error: {str(e)}") - if label != "*": + if node_label != "*": logger.warning( "Neo4j: falling back to basic Cypher recursive search..." ) @@ -857,12 +767,12 @@ class Neo4JStorage(BaseGraphStorage): logger.warning( "Neo4j: inclusive search mode is not supported in recursive query, using exact matching" ) - return await self._robust_fallback(label, max_depth, min_degree) + return await self._robust_fallback(node_label, max_depth, min_degree) return result async def _robust_fallback( - self, label: str, max_depth: int, min_degree: int = 0 + self, node_label: str, max_depth: int, min_degree: int = 0 ) -> KnowledgeGraph: """ Fallback implementation when APOC plugin is not available or incompatible. @@ -895,12 +805,11 @@ class Neo4JStorage(BaseGraphStorage): database=self._DATABASE, default_access_mode="READ" ) as session: query = """ - MATCH (a)-[r]-(b) - WHERE id(a) = toInteger($node_id) + MATCH (a:base {entity_id: $entity_id})-[r]-(b) WITH r, b, id(r) as edge_id, id(b) as target_id RETURN r, b, edge_id, target_id """ - results = await session.run(query, {"node_id": node.id}) + results = await session.run(query, entity_id=node.id) # Get all records and release database connection records = await results.fetch( @@ -928,14 +837,14 @@ class Neo4JStorage(BaseGraphStorage): edge_id = str(record["edge_id"]) if edge_id not in visited_edges: b_node = record["b"] - target_id = str(record["target_id"]) + target_id = b_node.get("entity_id") - if b_node.labels: # Only process if target node has labels + if target_id: # Only process if target node has entity_id # Create KnowledgeGraphNode for target target_node = KnowledgeGraphNode( id=f"{target_id}", - labels=list(b_node.labels), - properties=dict(b_node), + labels=[label for label in b_node.labels if label != "base"], + 
properties=dict(b_node), ) # Create KnowledgeGraphEdge @@ -961,11 +870,11 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = f""" - MATCH (n:`{label}`) + query = """ + MATCH (n:base {entity_id: $entity_id}) RETURN id(n) as node_id, n """ - node_result = await session.run(query) + node_result = await session.run(query, entity_id=node_label) try: node_record = await node_result.single() if not node_record: @@ -973,9 +882,9 @@ class Neo4JStorage(BaseGraphStorage): # Create initial KnowledgeGraphNode start_node = KnowledgeGraphNode( - id=f"{node_record['node_id']}", - labels=list(node_record["n"].labels), - properties=dict(node_record["n"]), + id=f"{node_record['n'].get('entity_id')}", + labels=[label for label in node_record["n"].labels if label != "base"], + properties=dict(node_record["n"]), ) finally: await node_result.consume() # Ensure results are consumed @@ -999,11 +908,10 @@ class Neo4JStorage(BaseGraphStorage): # Method 2: Query compatible with older versions query = """ - MATCH (n) - WITH DISTINCT labels(n) AS node_labels - UNWIND node_labels AS label - RETURN DISTINCT label - ORDER BY label + MATCH (n) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label """ result = await session.run(query) labels = [] @@ -1034,15 +942,13 @@ class Neo4JStorage(BaseGraphStorage): Args: node_id: The label of the node to delete """ - label = self._ensure_label(node_id) - async def _do_delete(tx: AsyncManagedTransaction): - query = f""" - MATCH (n:`{label}`) + query = """ + MATCH (n:base {entity_id: $entity_id}) DETACH DELETE n """ - result = await tx.run(query) - logger.debug(f"Deleted node with label '{label}'") + result = await tx.run(query, entity_id=node_id) + logger.debug(f"Deleted node with label '{node_id}'") await result.consume() # Ensure result is fully consumed try: @@ -1092,16 +998,13 @@ class 
Neo4JStorage(BaseGraphStorage): edges: List of edges to be deleted, each edge is a (source, target) tuple """ for source, target in edges: - source_label = self._ensure_label(source) - target_label = self._ensure_label(target) - async def _do_delete_edge(tx: AsyncManagedTransaction): - query = f""" - MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`) + query = """ + MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) DELETE r """ - result = await tx.run(query) - logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") + result = await tx.run(query, source_entity_id=source, target_entity_id=target) + logger.debug(f"Deleted edge from '{source}' to '{target}'") await result.consume() # Ensure result is fully consumed try: