diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 18bd6859..d0c6c779 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -195,14 +195,12 @@ class Neo4JStorage(BaseGraphStorage): ) as session: try: query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" - result = await session.run(query, entity_id = node_id) + result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed return single_result["node_exists"] except Exception as e: - logger.error( - f"Error checking node existence for {node_id}: {str(e)}" - ) + logger.error(f"Error checking node existence for {node_id}: {str(e)}") await result.consume() # Ensure results are consumed even on error raise @@ -229,7 +227,11 @@ class Neo4JStorage(BaseGraphStorage): "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " "RETURN COUNT(r) > 0 AS edgeExists" ) - result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) single_result = await result.single() await result.consume() # Ensure result is fully consumed return single_result["edgeExists"] @@ -274,7 +276,11 @@ class Neo4JStorage(BaseGraphStorage): node_dict = dict(node) # Remove base label from labels list if it exists if "labels" in node_dict: - node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + node_dict["labels"] = [ + label + for label in node_dict["labels"] + if label != "base" + ] logger.debug(f"Neo4j query node {query} return: {node_dict}") return node_dict return None @@ -308,25 +314,23 @@ class Neo4JStorage(BaseGraphStorage): OPTIONAL MATCH (n)-[r]-() RETURN COUNT(r) AS degree """ - result = await session.run(query, entity_id = node_id) + result = await session.run(query, entity_id=node_id) try: record = await result.single() if not record: - logger.warning( - f"No node found with label '{node_id}'" - ) + logger.warning(f"No node found with label '{node_id}'") return 0 degree = record["degree"] - logger.debug("Neo4j query node degree for {node_id} return: {degree}") + logger.debug( + "Neo4j query node degree for {node_id} return: {degree}" + ) return degree finally: await result.consume() # Ensure result is fully consumed except Exception as e: - logger.error( - f"Error getting node degree for {node_id}: {str(e)}" - ) + logger.error(f"Error getting node degree for {node_id}: {str(e)}") raise async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -373,7 +377,11 @@ class Neo4JStorage(BaseGraphStorage): MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) RETURN properties(r) as edge_properties """ - result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) try: records = await result.fetch(2) @@ -471,10 +479,14 @@ class Neo4JStorage(BaseGraphStorage): continue source_label = ( - source_node.get("entity_id") if source_node.get("entity_id") else None + source_node.get("entity_id") + if source_node.get("entity_id") + else None ) target_label = ( - connected_node.get("entity_id") if connected_node.get("entity_id") else None + connected_node.get("entity_id") + if connected_node.get("entity_id") + else None ) if source_label and target_label: @@ -483,7 +495,9 @@ class Neo4JStorage(BaseGraphStorage): await results.consume() # Ensure results are consumed return edges except Exception as e: - logger.error(f"Error getting edges for node {source_node_id}: {str(e)}") + logger.error( + f"Error getting edges for node {source_node_id}: {str(e)}" + ) await results.consume() # Ensure results are consumed even on error raise except Exception as e: @@ -520,11 +534,14 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = """ + query = ( + """ MERGE (n:base {entity_id: $properties.entity_id}) SET n += $properties SET n:`%s` - """ % entity_type + """ + % entity_type + ) result = await tx.run(query, properties=properties) logger.debug( f"Upserted node with entity_id '{entity_id}' and properties: {properties}" @@ -548,7 +565,6 @@ class Neo4JStorage(BaseGraphStorage): ) ), ) - @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -728,7 +744,11 @@ class Neo4JStorage(BaseGraphStorage): result.nodes.append( KnowledgeGraphNode( id=f"{node_id}", - labels=[label for label in node.labels if label != "base"], + labels=[ + label + for label in node.labels + if label != "base" + ], properties=dict(node), ) ) @@ -767,7 +787,9 @@ class Neo4JStorage(BaseGraphStorage): logger.warning( "Neo4j: inclusive search mode is not supported in recursive query, using exact matching" ) - return await self._robust_fallback(node_label, max_depth, min_degree) + return await self._robust_fallback( + node_label, max_depth, min_degree + ) return result @@ -843,7 +865,9 @@ class Neo4JStorage(BaseGraphStorage): # Create KnowledgeGraphNode for target target_node = KnowledgeGraphNode( id=f"{target_id}", - labels=[label for label in b_node.labels if label != "base"], + labels=[ + label for label in b_node.labels if label != "base" + ], properties=dict(b_node.properties), ) @@ -883,7 +907,9 @@ class Neo4JStorage(BaseGraphStorage): # Create initial KnowledgeGraphNode start_node = KnowledgeGraphNode( id=f"{node_record['n'].get('entity_id')}", - labels=[label for label in node_record["n"].labels if label != "base"], + labels=[ + label for label in node_record["n"].labels if label != "base" + ], properties=dict(node_record["n"].properties), ) finally: @@ -942,6 +968,7 @@ class Neo4JStorage(BaseGraphStorage): Args: node_id: The label of the node to delete """ + async def _do_delete(tx: AsyncManagedTransaction): query = """ MATCH (n:base {entity_id: $entity_id}) @@ -998,12 +1025,15 @@ class Neo4JStorage(BaseGraphStorage): edges: List of edges to be deleted, each edge is a (source, target) tuple """ for source, target in edges: + async def _do_delete_edge(tx: AsyncManagedTransaction): query = """ MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) DELETE r """ - result = await tx.run(query, source_entity_id=source, target_entity_id=target) + result = await tx.run( + query, source_entity_id=source, target_entity_id=target + ) logger.debug(f"Deleted edge from '{source}' to '{target}'") await result.consume() # Ensure result is fully consumed