From cd81312659630cde494b34bf26f73720187f80fc Mon Sep 17 00:00:00 2001 From: Pankaj Kaushal Date: Fri, 14 Feb 2025 16:04:06 +0100 Subject: [PATCH] Enhance Neo4j graph storage with error handling and label validation - Add label existence check and validation methods in Neo4j implementation - Improve error handling in get_node, get_edge, and upsert methods - Add default values and logging for missing edge properties - Ensure consistent label processing across graph storage methods --- lightrag/kg/neo4j_impl.py | 134 ++++++++++++++++++++++++++++---------- lightrag/operate.py | 62 ++++++++++++++---- 2 files changed, 150 insertions(+), 46 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index e9a53110..15525375 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage): async def index_done_callback(self): print("KG successfully indexed.") - async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') + async def _label_exists(self, label: str) -> bool: + """Check if a label exists in the Neo4j database.""" + query = "CALL db.labels() YIELD label RETURN label" + try: + async with self._driver.session(database=self._DATABASE) as session: + result = await session.run(query) + labels = [record["label"] for record in await result.data()] + return label in labels + except Exception as e: + logger.error(f"Error checking label existence: {e}") + return False + async def _ensure_label(self, label: str) -> str: + """Ensure a label exists by validating it.""" + clean_label = label.strip('"') + if not await self._label_exists(clean_label): + logger.warning(f"Label '{clean_label}' does not exist in Neo4j") + return clean_label + + async def has_node(self, node_id: str) -> bool: + entity_name_label = await self._ensure_label(node_id) async with self._driver.session(database=self._DATABASE) as session: query = ( f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" @@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage): return single_result["edgeExists"] async def get_node(self, node_id: str) -> Union[dict, None]: + """Get node by its label identifier. + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + """ async with self._driver.session(database=self._DATABASE) as session: - entity_name_label = node_id.strip('"') + entity_name_label = await self._ensure_label(node_id) query = f"MATCH (n:`{entity_name_label}`) RETURN n" result = await session.run(query) record = await result.single() @@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') - """ - Find all edges between nodes of two given labels + """Find edge between two nodes identified by their labels. Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes + source_node_id (str): Label of the source node + target_node_id (str): Label of the target node Returns: - list: List of all relationships/edges found + dict: Edge properties if found, with at least {"weight": 0.0} + None: If error occurs """ - async with self._driver.session(database=self._DATABASE) as session: - query = f""" - MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) - RETURN properties(r) as edge_properties - LIMIT 1 - """.format( - entity_name_label_source=entity_name_label_source, - entity_name_label_target=entity_name_label_target, - ) + try: + entity_name_label_source = source_node_id.strip('"') + entity_name_label_target = target_node_id.strip('"') - result = await session.run(query) - record = await result.single() - if record: - result = dict(record["edge_properties"]) - logger.debug( - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + async with self._driver.session(database=self._DATABASE) as session: + query = f""" + MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) + RETURN properties(r) as edge_properties + LIMIT 1 + """.format( + entity_name_label_source=entity_name_label_source, + entity_name_label_target=entity_name_label_target, ) - return result - else: - return None + + result = await session.run(query) + record = await result.single() + if record and "edge_properties" in record: + try: + result = dict(record["edge_properties"]) + # Ensure required keys exist with defaults + required_keys = { + "weight": 0.0, + "source_id": None, + "target_id": None, + } + for key, default_value in required_keys.items(): + if key not in result: + result[key] = default_value + logger.warning( + f"Edge between {entity_name_label_source} and {entity_name_label_target} " + f"missing {key}, using default: {default_value}" + ) + + logger.debug( + f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" + ) + return result + except (KeyError, TypeError, ValueError) as e: + logger.error( + f"Error processing edge properties between {entity_name_label_source} " + f"and {entity_name_label_target}: {str(e)}" + ) + # Return default edge properties on error + return {"weight": 0.0, "source_id": None, "target_id": None} + + logger.debug( + f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}" + ) + # Return default edge properties when no edge found + return {"weight": 0.0, "source_id": None, "target_id": None} + + except Exception as e: + logger.error( + f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + # Return default edge properties on error + return {"weight": 0.0, "source_id": None, "target_id": None} async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: node_label = source_node_id.strip('"') @@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = node_id.strip('"') + label = await self._ensure_label(node_id) properties = node_data async def _do_upsert(tx: AsyncManagedTransaction): @@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage): neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable, + neo4jExceptions.ClientError, ) ), ) @@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') + source_label = await self._ensure_label(source_node_id) + target_label = await self._ensure_label(target_node_id) edge_properties = edge_data async def _do_upsert_edge(tx: AsyncManagedTransaction): query = f""" - MATCH (source:`{source_node_label}`) + MATCH (source:`{source_label}`) WITH source - MATCH (target:`{target_node_label}`) + MATCH (target:`{target_label}`) MERGE (source)-[r:DIRECTED]->(target) SET r += $properties RETURN r """ - await tx.run(query, properties=edge_properties) + result = await tx.run(query, properties=edge_properties) + record = await result.single() logger.debug( - f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}" + f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}" ) try: diff --git a/lightrag/operate.py b/lightrag/operate.py index 04aad0d4..8cf77f57 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -237,25 +237,65 @@ async def _merge_edges_then_upsert( if await knowledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) - already_weights.append(already_edge["weight"]) - already_source_ids.extend( - split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) - ) - already_description.append(already_edge["description"]) - already_keywords.extend( - split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) - ) + # Handle the case where get_edge returns None or missing fields + if already_edge: + # Get weight with default 0.0 if missing + if "weight" in already_edge: + already_weights.append(already_edge["weight"]) + else: + logger.warning( + f"Edge between {src_id} and {tgt_id} missing weight field" + ) + already_weights.append(0.0) + # Get source_id with empty string default if missing or None + if "source_id" in already_edge and already_edge["source_id"] is not None: + already_source_ids.extend( + split_string_by_multi_markers( + already_edge["source_id"], [GRAPH_FIELD_SEP] + ) + ) + + # Get description with empty string default if missing or None + if ( + "description" in already_edge + and already_edge["description"] is not None + ): + already_description.append(already_edge["description"]) + + # Get keywords with empty string default if missing or None + if "keywords" in already_edge and already_edge["keywords"] is not None: + already_keywords.extend( + split_string_by_multi_markers( + already_edge["keywords"], [GRAPH_FIELD_SEP] + ) + ) + + # Process edges_data with None checks weight = sum([dp["weight"] for dp in edges_data] + already_weights) description = GRAPH_FIELD_SEP.join( - sorted(set([dp["description"] for dp in edges_data] + already_description)) + sorted( + set( + [dp["description"] for dp in edges_data if dp.get("description")] + + already_description + ) + ) ) keywords = GRAPH_FIELD_SEP.join( - sorted(set([dp["keywords"] for dp in edges_data] + already_keywords)) + sorted( + set( + [dp["keywords"] for dp in edges_data if dp.get("keywords")] + + already_keywords + ) + ) ) source_id = GRAPH_FIELD_SEP.join( - set([dp["source_id"] for dp in edges_data] + already_source_ids) + set( + [dp["source_id"] for dp in edges_data if dp.get("source_id")] + + already_source_ids + ) ) + for need_insert_id in [src_id, tgt_id]: if not (await knowledge_graph_inst.has_node(need_insert_id)): await knowledge_graph_inst.upsert_node(