Enhance Neo4j graph storage with error handling and label validation

- Add label existence check and validation methods in Neo4j implementation
- Improve error handling in get_node, get_edge, and upsert methods
- Add default values and logging for missing edge properties
- Ensure consistent label processing across graph storage methods
This commit is contained in:
Pankaj Kaushal
2025-02-14 16:04:06 +01:00
parent 4d58ff8bb4
commit cd81312659
2 changed files with 150 additions and 46 deletions

View File

@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
async def index_done_callback(self): async def index_done_callback(self):
print("KG successfully indexed.") print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool: async def _label_exists(self, label: str) -> bool:
entity_name_label = node_id.strip('"') """Check if a label exists in the Neo4j database."""
query = "CALL db.labels() YIELD label RETURN label"
try:
async with self._driver.session(database=self._DATABASE) as session:
result = await session.run(query)
labels = [record["label"] for record in await result.data()]
return label in labels
except Exception as e:
logger.error(f"Error checking label existence: {e}")
return False
async def _ensure_label(self, label: str) -> str:
"""Ensure a label exists by validating it."""
clean_label = label.strip('"')
if not await self._label_exists(clean_label):
logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
return clean_label
async def has_node(self, node_id: str) -> bool:
entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
query = ( query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"] return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
"""Get node by its label identifier.
Args:
node_id: The node label to look up
Returns:
dict: Node properties if found
None: If node not found
"""
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"') entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n" query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query) result = await session.run(query)
record = await result.single() record = await result.single()
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"') """Find edge between two nodes identified by their labels.
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
Args: Args:
source_node_label (str): Label of the source nodes source_node_id (str): Label of the source node
target_node_label (str): Label of the target nodes target_node_id (str): Label of the target node
Returns: Returns:
list: List of all relationships/edges found dict: Edge properties if found, with at least {"weight": 0.0}
None: If error occurs
""" """
async with self._driver.session(database=self._DATABASE) as session: try:
query = f""" entity_name_label_source = source_node_id.strip('"')
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) entity_name_label_target = target_node_id.strip('"')
RETURN properties(r) as edge_properties
LIMIT 1
""".format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
)
result = await session.run(query) async with self._driver.session(database=self._DATABASE) as session:
record = await result.single() query = f"""
if record: MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
result = dict(record["edge_properties"]) RETURN properties(r) as edge_properties
logger.debug( LIMIT 1
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}" """.format(
entity_name_label_source=entity_name_label_source,
entity_name_label_target=entity_name_label_target,
) )
return result
else: result = await session.run(query)
return None record = await result.single()
if record and "edge_properties" in record:
try:
result = dict(record["edge_properties"])
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"target_id": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
# Return default edge properties when no edge found
return {"weight": 0.0, "source_id": None, "target_id": None}
except Exception as e:
logger.error(
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
)
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"') node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
label = node_id.strip('"') label = await self._ensure_label(node_id)
properties = node_data properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction): async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable, neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError, neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable, neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
) )
), ),
) )
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
""" """
source_node_label = source_node_id.strip('"') source_label = await self._ensure_label(source_node_id)
target_node_label = target_node_id.strip('"') target_label = await self._ensure_label(target_node_id)
edge_properties = edge_data edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction): async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f""" query = f"""
MATCH (source:`{source_node_label}`) MATCH (source:`{source_label}`)
WITH source WITH source
MATCH (target:`{target_node_label}`) MATCH (target:`{target_label}`)
MERGE (source)-[r:DIRECTED]->(target) MERGE (source)-[r:DIRECTED]->(target)
SET r += $properties SET r += $properties
RETURN r RETURN r
""" """
await tx.run(query, properties=edge_properties) result = await tx.run(query, properties=edge_properties)
record = await result.single()
logger.debug( logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}" f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
) )
try: try:

View File

@@ -237,25 +237,65 @@ async def _merge_edges_then_upsert(
if await knowledge_graph_inst.has_edge(src_id, tgt_id): if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"]) # Handle the case where get_edge returns None or missing fields
already_source_ids.extend( if already_edge:
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) # Get weight with default 0.0 if missing
) if "weight" in already_edge:
already_description.append(already_edge["description"]) already_weights.append(already_edge["weight"])
already_keywords.extend( else:
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP]) logger.warning(
) f"Edge between {src_id} and {tgt_id} missing weight field"
)
already_weights.append(0.0)
# Get source_id with empty string default if missing or None
if "source_id" in already_edge and already_edge["source_id"] is not None:
already_source_ids.extend(
split_string_by_multi_markers(
already_edge["source_id"], [GRAPH_FIELD_SEP]
)
)
# Get description with empty string default if missing or None
if (
"description" in already_edge
and already_edge["description"] is not None
):
already_description.append(already_edge["description"])
# Get keywords with empty string default if missing or None
if "keywords" in already_edge and already_edge["keywords"] is not None:
already_keywords.extend(
split_string_by_multi_markers(
already_edge["keywords"], [GRAPH_FIELD_SEP]
)
)
# Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights) weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join( description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description)) sorted(
set(
[dp["description"] for dp in edges_data if dp.get("description")]
+ already_description
)
)
) )
keywords = GRAPH_FIELD_SEP.join( keywords = GRAPH_FIELD_SEP.join(
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords)) sorted(
set(
[dp["keywords"] for dp in edges_data if dp.get("keywords")]
+ already_keywords
)
)
) )
source_id = GRAPH_FIELD_SEP.join( source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in edges_data] + already_source_ids) set(
[dp["source_id"] for dp in edges_data if dp.get("source_id")]
+ already_source_ids
)
) )
for need_insert_id in [src_id, tgt_id]: for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)): if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node( await knowledge_graph_inst.upsert_node(