diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f6567249..ea316d0f 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage): MAX_CONNECTION_POOL_SIZE = int( os.environ.get( "NEO4J_MAX_CONNECTION_POOL_SIZE", - config.get("neo4j", "connection_pool_size", fallback=50), # Reduced from 800 + config.get("neo4j", "connection_pool_size", fallback=50), ) ) CONNECTION_TIMEOUT = float( os.environ.get( "NEO4J_CONNECTION_TIMEOUT", - config.get("neo4j", "connection_timeout", fallback=30.0), # Reduced from 60.0 + config.get("neo4j", "connection_timeout", fallback=30.0), ), ) CONNECTION_ACQUISITION_TIMEOUT = float( os.environ.get( "NEO4J_CONNECTION_ACQUISITION_TIMEOUT", - config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), # Reduced from 60.0 + config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), ), ) MAX_TRANSACTION_RETRY_TIME = float( @@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: entity_name_label = await self._ensure_label(node_id) - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" ) result = await session.run(query) single_result = await result.single() await result.consume() # Ensure result is fully consumed - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}" - ) return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = ( f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " "RETURN COUNT(r) > 0 AS edgeExists" @@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) single_result = await result.single() await result.consume() # Ensure result is fully consumed - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}" - ) return single_result["edgeExists"] async def get_node(self, node_id: str) -> dict[str, str] | None: @@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage): dict: Node properties if found None: If node not found """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) - records = await result.fetch(2) # Get up to 2 records to check for duplicates + records = await result.fetch( + 2 + ) # Get up to 2 records to check for duplicates await result.consume() # Ensure result is fully consumed if len(records) > 1: - logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.") + logger.warning( + f"Multiple nodes found with label '{entity_name_label}'. Using first node." + ) if records: node = records[0]["n"] node_dict = dict(node) @@ -248,16 +252,18 @@ class Neo4JStorage(BaseGraphStorage): """Get the degree (number of relationships) of a node with the given label. If multiple nodes have the same label, returns the degree of the first node. If no node is found, returns 0. - + Args: node_id: The label of the node - + Returns: int: The number of relationships the node has, or 0 if no node found """ entity_name_label = node_id.strip('"') - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = f""" MATCH (n:`{entity_name_label}`) OPTIONAL MATCH (n)-[r]-() @@ -266,14 +272,16 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) records = await result.fetch(100) await result.consume() # Ensure result is fully consumed - + if not records: logger.warning(f"No node found with label '{entity_name_label}'") return 0 - + if len(records) > 1: - logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree") - + logger.warning( + f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" + ) + degree = records[0]["degree"] logger.debug( f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" @@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage): ) return degrees - async def check_duplicate_nodes(self) -> list[tuple[str, int]]: - """Find all labels that have multiple nodes - - Returns: - list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes - """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: - query = """ - MATCH (n) - WITH labels(n) as nodeLabels - UNWIND nodeLabels as label - WITH label, count(*) as node_count - WHERE node_count > 1 - RETURN label, node_count - ORDER BY node_count DESC - """ - result = await session.run(query) - duplicates = [] - async for record in result: - label = record["label"] - count = record["node_count"] - logger.info(f"Found {count} nodes with label: {label}") - duplicates.append((label, count)) - return duplicates async def get_edge( self, source_node_id: str, target_node_id: str @@ -328,64 +312,69 @@ class Neo4JStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: query = f""" MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`) RETURN properties(r) as edge_properties """ result = await session.run(query) - records = await result.fetch(2) # Get up to 2 records to check for duplicates - if len(records) > 1: - logger.warning( - f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + try: + records = await result.fetch(2) # Get up to 2 records to check for duplicates + if len(records) > 1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + ) + if records: + try: + result = dict(records[0]["edge_properties"]) + logger.debug(f"Result: {result}") + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in result: + result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + ) + return result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" ) - if records: - try: - result = dict(records[0]["edge_properties"]) - logger.debug(f"Result: {result}") - # Ensure required keys exist with defaults - required_keys = { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } - for key, default_value in required_keys.items(): - if key not in result: - result[key] = default_value - logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " - f"missing {key}, using default: {default_value}" - ) - - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" - ) - return result - except (KeyError, TypeError, ValueError) as e: - logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" - ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - - logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" - ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } + # Return default edge properties when no edge found + return { + "weight": 0.0, + "description": None, + "keywords": None, + "source_id": None, + } + finally: + await result.consume() # Ensure result is fully consumed except Exception as e: logger.error( @@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage): query = f"""MATCH (n:`{node_label}`) OPTIONAL MATCH (n)-[r]-(connected) RETURN n, r, connected""" - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: results = await session.run(query) edges = [] try: @@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage): if source_label and target_label: edges.append((source_label, target_label)) finally: - await results.consume() # Ensure results are consumed even if processing fails + await ( + results.consume() + ) # Ensure results are consumed even if processing fails return edges @@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage): MERGE (n:`{label}`) SET n += $properties """ - await tx.run(query, properties=properties) + result = await tx.run(query, properties=properties) logger.debug( f"Upserted node with label '{label}' and properties: {properties}" ) + await result.consume() # Ensure result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: @@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage): target_exists = await self.has_node(target_label) if not source_exists: - raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist") + raise ValueError( + f"Neo4j: source node with label '{source_label}' does not exist" + ) if not target_exists: - raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist") + raise ValueError( + f"Neo4j: target node with label '{target_label}' does not exist" + ) async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" @@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: try: if label == "*": main_query = """ @@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage): visited_nodes.add(node_id) # Add node data with label as ID - result["nodes"].append({ - "id": current_label, - "labels": current_label, - "properties": node - }) + result["nodes"].append( + {"id": current_label, "labels": current_label, "properties": node} + ) # Get connected nodes that meet the degree requirement # Note: We don't need to check a's degree since it's the current node @@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage): WHERE b_degree >= $min_degree OR EXISTS((a)--(b)) RETURN r, b """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: results = await session.run(query, {"min_degree": min_degree}) async for record in results: # Handle edges @@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage): b_node = record["b"] if b_node.labels: # Only process if target node has labels target_label = list(b_node.labels)[0] - result["edges"].append({ - "id": f"{current_label}_{target_label}", - "type": rel.type, - "source": current_label, - "target": target_label, - "properties": dict(rel) - }) + result["edges"].append( + { + "id": f"{current_label}_{target_label}", + "type": rel.type, + "source": current_label, + "target": target_label, + "properties": dict(rel), + } + ) visited_edges.add(edge_id) # Continue traversal await traverse(target_label, current_depth + 1) else: - logger.warning(f"Skipping edge {edge_id} due to missing labels on target node") + logger.warning( + f"Skipping edge {edge_id} due to missing labels on target node" + ) await traverse(label, 0) return result @@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage): Returns: ["Person", "Company", ...] # Alphabetically sorted label list """ - async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: # Method 1: Direct metadata query (Available for Neo4j 4.3+) # query = "CALL db.labels() YIELD label RETURN label" @@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage): async for record in result: labels.append(record["label"]) finally: - await result.consume() # Ensure results are consumed even if processing fails + await ( + result.consume() + ) # Ensure results are consumed even if processing fails return labels @retry( @@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage): MATCH (n:`{label}`) DETACH DELETE n """ - await tx.run(query) + result = await tx.run(query) logger.debug(f"Deleted node with label '{label}'") + await result.consume() # Ensure result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: @@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage): MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`) DELETE r """ - await tx.run(query) + result = await tx.run(query) logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") + await result.consume() # Ensure result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: