diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 60e8982e..082b4bf2 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -163,13 +163,14 @@ class Neo4JStorage(BaseGraphStorage): } async def close(self): + """Close the Neo4j driver and release all resources""" if self._driver: await self._driver.close() self._driver = None async def __aexit__(self, exc_type, exc, tb): - if self._driver: - await self._driver.close() + """Ensure driver is closed when context manager exits""" + await self.close() async def index_done_callback(self) -> None: # Noe4J handles persistence automatically @@ -187,33 +188,72 @@ class Neo4JStorage(BaseGraphStorage): return clean_label async def has_node(self, node_id: str) -> bool: + """ + Check if a node with the given label exists in the database + + Args: + node_id: Label of the node to check + + Returns: + bool: True if node exists, False otherwise + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query + """ entity_name_label = await self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" - ) - result = await session.run(query) - single_result = await result.single() - await result.consume() # Ensure result is fully consumed - return single_result["node_exists"] + try: + query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" + result = await session.run(query) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["node_exists"] + except Exception as e: + logger.error( + f"Error checking node existence for {entity_name_label}: {str(e)}" + ) + await result.consume() # Ensure results are consumed even on error + raise async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + """ + Check if an edge exists between two nodes + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + bool: True if edge exists, False otherwise + + Raises: + ValueError: If either node_id is invalid + Exception: If there is an error executing the query + """ + entity_name_label_source = await self._ensure_label(source_node_id) + entity_name_label_target = await self._ensure_label(target_node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = ( - f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " - "RETURN COUNT(r) > 0 AS edgeExists" - ) - result = await session.run(query) - single_result = await result.single() - await result.consume() # Ensure result is fully consumed - return single_result["edgeExists"] + try: + query = ( + f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run(query) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return single_result["edgeExists"] + except Exception as e: + logger.error( + f"Error checking edge existence between {entity_name_label_source} and {entity_name_label_target}: {str(e)}" + ) + await result.consume() # Ensure results are consumed even on error + raise async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier. @@ -224,29 +264,40 @@ class Neo4JStorage(BaseGraphStorage): Returns: dict: Node properties if found None: If node not found + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query """ + entity_name_label = await self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - entity_name_label = await self._ensure_label(node_id) - query = f"MATCH (n:`{entity_name_label}`) RETURN n" - result = await session.run(query) - records = await result.fetch( - 2 - ) # Get up to 2 records to check for duplicates - await result.consume() # Ensure result is fully consumed - if len(records) > 1: - logger.warning( - f"Multiple nodes found with label '{entity_name_label}'. Using first node." - ) - if records: - node = records[0]["n"] - node_dict = dict(node) - logger.debug( - f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" - ) - return node_dict - return None + try: + query = f"MATCH (n:`{entity_name_label}`) RETURN n" + result = await session.run(query) + try: + records = await result.fetch( + 2 + ) # Get up to 2 records to check for duplicates + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{entity_name_label}'. Using first node." + ) + if records: + node = records[0]["n"] + node_dict = dict(node) + logger.debug( + f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}" + ) + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {entity_name_label}: {str(e)}") + raise async def node_degree(self, node_id: str) -> int: """Get the degree (number of relationships) of a node with the given label. @@ -258,39 +309,63 @@ class Neo4JStorage(BaseGraphStorage): Returns: int: The number of relationships the node has, or 0 if no node found + + Raises: + ValueError: If node_id is invalid + Exception: If there is an error executing the query """ - entity_name_label = node_id.strip('"') + entity_name_label = await self._ensure_label(node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = f""" - MATCH (n:`{entity_name_label}`) - OPTIONAL MATCH (n)-[r]-() - RETURN n, COUNT(r) AS degree - """ - result = await session.run(query) - records = await result.fetch(100) - await result.consume() # Ensure result is fully consumed + try: + query = f""" + MATCH (n:`{entity_name_label}`) + OPTIONAL MATCH (n)-[r]-() + RETURN n, COUNT(r) AS degree + """ + result = await session.run(query) + try: + records = await result.fetch(100) - if not records: - logger.warning(f"No node found with label '{entity_name_label}'") - return 0 + if not records: + logger.warning( + f"No node found with label '{entity_name_label}'" + ) + return 0 - if len(records) > 1: - logger.warning( - f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" + if len(records) > 1: + logger.warning( + f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree" + ) + + degree = records[0]["degree"] + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" + ) + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error( + f"Error getting node degree for {entity_name_label}: {str(e)}" ) - - degree = records[0]["degree"] - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}" - ) - return degree + raise async def edge_degree(self, src_id: str, tgt_id: str) -> int: - entity_name_label_source = src_id.strip('"') - entity_name_label_target = tgt_id.strip('"') + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ + entity_name_label_source = await self._ensure_label(src_id) + entity_name_label_target = await self._ensure_label(tgt_id) + src_degree = await self.node_degree(entity_name_label_source) trg_degree = await self.node_degree(entity_name_label_target) @@ -299,17 +374,27 @@ class Neo4JStorage(BaseGraphStorage): trg_degree = 0 if trg_degree is None else trg_degree degrees = int(src_degree) + int(trg_degree) - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}" - ) return degrees async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + ValueError: If either node_id is invalid + Exception: If there is an error executing the query + """ try: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + entity_name_label_source = await self._ensure_label(source_node_id) + entity_name_label_target = await self._ensure_label(target_node_id) async with self._driver.session( database=self._DATABASE, default_access_mode="READ" @@ -320,109 +405,123 @@ class Neo4JStorage(BaseGraphStorage): """ result = await session.run(query) - records = await result.fetch(2) # Get up to 2 records to check for duplicates - await result.consume() # Ensure result is fully consumed before processing records - - if len(records) > 1: - logger.warning( - f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + try: + records = await result.fetch( + 2 + ) # Get up to 2 records to check for duplicates + + if len(records) > 1: + logger.warning( + f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." + ) + if records: + try: + edge_result = dict(records[0]["edge_properties"]) + logger.debug(f"Result: {edge_result}") + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + for key, default_value in required_keys.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}" + ) + return edge_result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" ) - if records: - try: - edge_result = dict(records[0]["edge_properties"]) - logger.debug(f"Result: {edge_result}") - # Ensure required keys exist with defaults - required_keys = { - "weight": 0.0, - "source_id": None, - "description": None, - "keywords": None, - } - for key, default_value in required_keys.items(): - if key not in edge_result: - edge_result[key] = default_value - logger.warning( - f"Edge between {entity_name_label_source} and {entity_name_label_target} " - f"missing {key}, using default: {default_value}" - ) - - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}" - ) - return edge_result - except (KeyError, TypeError, ValueError) as e: - logger.error( - f"Error processing edge properties between {entity_name_label_source} " - f"and {entity_name_label_target}: {str(e)}" - ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } - - logger.debug( - f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" - ) - # Return default edge properties when no edge found - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } + # Return default edge properties when no edge found + return { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + } + finally: + await result.consume() # Ensure result is fully consumed except Exception as e: logger.error( f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" ) - # Return default edge properties on error - return { - "weight": 0.0, - "description": None, - "keywords": None, - "source_id": None, - } + raise async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - node_label = source_node_id.strip('"') + """Retrieves all edges (relationships) for a particular node identified by its label. + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + ValueError: If source_node_id is invalid + Exception: If there is an error executing the query """ - Retrieves all edges (relationships) for a particular node identified by its label. - :return: List of dictionaries containing edge information - """ - query = f"""MATCH (n:`{node_label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected""" - async with self._driver.session( - database=self._DATABASE, default_access_mode="READ" - ) as session: - results = await session.run(query) - edges = [] - try: - async for record in results: - source_node = record["n"] - connected_node = record["connected"] + try: + node_label = await self._ensure_label(source_node_id) - source_label = ( - list(source_node.labels)[0] if source_node.labels else None - ) - target_label = ( - list(connected_node.labels)[0] - if connected_node and connected_node.labels - else None - ) + query = f"""MATCH (n:`{node_label}`) + OPTIONAL MATCH (n)-[r]-(connected) + RETURN n, r, connected""" - if source_label and target_label: - edges.append((source_label, target_label)) - finally: - await ( - results.consume() - ) # Ensure results are consumed even if processing fails + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + results = await session.run(query) + edges = [] - return edges + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + source_label = ( + list(source_node.labels)[0] if source_node.labels else None + ) + target_label = ( + list(connected_node.labels)[0] + if connected_node and connected_node.labels + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges if edges else None + except Exception as e: + logger.error(f"Error getting edges for node {node_label}: {str(e)}") + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise @retry( stop=stop_after_attempt(3), @@ -838,7 +937,6 @@ class Neo4JStorage(BaseGraphStorage): RETURN DISTINCT label ORDER BY label """ - result = await session.run(query) labels = [] try: