Merge pull request #780 from spo0nman/fix-bug-778

Enhance Neo4j graph storage with error handling and label validation
This commit is contained in:
zrguo
2025-02-16 19:30:32 +08:00
committed by GitHub
2 changed files with 150 additions and 46 deletions

View File

@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
async def index_done_callback(self):
print("KG successfully indexed.")
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
async def _label_exists(self, label: str) -> bool:
"""Check if a label exists in the Neo4j database."""
query = "CALL db.labels() YIELD label RETURN label"
try:
async with self._driver.session(database=self._DATABASE) as session:
result = await session.run(query)
labels = [record["label"] for record in await result.data()]
return label in labels
except Exception as e:
logger.error(f"Error checking label existence: {e}")
return False
async def _ensure_label(self, label: str) -> str:
"""Ensure a label exists by validating it."""
clean_label = label.strip('"')
if not await self._label_exists(clean_label):
logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
return clean_label
async def has_node(self, node_id: str) -> bool:
entity_name_label = await self._ensure_label(node_id)
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
return single_result["edgeExists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
"""Get node by its label identifier.
Args:
node_id: The node label to look up
Returns:
dict: Node properties if found
None: If node not found
"""
async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"')
entity_name_label = await self._ensure_label(node_id)
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
record = await result.single()
@@ -226,18 +253,20 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
"""
Find all edges between nodes of two given labels
"""Find edge between two nodes identified by their labels.
Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
source_node_id (str): Label of the source node
target_node_id (str): Label of the target node
Returns:
list: List of all relationships/edges found
dict: Edge properties if found, with at least {"weight": 0.0}
None: If error occurs
"""
try:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
@@ -250,14 +279,47 @@ class Neo4JStorage(BaseGraphStorage):
result = await session.run(query)
record = await result.single()
if record:
if record and "edge_properties" in record:
try:
result = dict(record["edge_properties"])
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"target_id": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
else:
return None
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
# Return default edge properties when no edge found
return {"weight": 0.0, "source_id": None, "target_id": None}
except Exception as e:
logger.error(
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
)
# Return default edge properties on error
return {"weight": 0.0, "source_id": None, "target_id": None}
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
node_label = source_node_id.strip('"')
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
label = await self._ensure_label(node_id)
properties = node_data
async def _do_upsert(tx: AsyncManagedTransaction):
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
source_label = await self._ensure_label(source_node_id)
target_label = await self._ensure_label(target_node_id)
edge_properties = edge_data
async def _do_upsert_edge(tx: AsyncManagedTransaction):
query = f"""
MATCH (source:`{source_node_label}`)
MATCH (source:`{source_label}`)
WITH source
MATCH (target:`{target_node_label}`)
MATCH (target:`{target_label}`)
MERGE (source)-[r:DIRECTED]->(target)
SET r += $properties
RETURN r
"""
await tx.run(query, properties=edge_properties)
result = await tx.run(query, properties=edge_properties)
record = await result.single()
logger.debug(
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
)
try:

View File

@@ -237,25 +237,65 @@ async def _merge_edges_then_upsert(
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
# Handle the case where get_edge returns None or missing fields
if already_edge:
# Get weight with default 0.0 if missing
if "weight" in already_edge:
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
else:
logger.warning(
f"Edge between {src_id} and {tgt_id} missing weight field"
)
already_weights.append(0.0)
# Get source_id with empty string default if missing or None
if "source_id" in already_edge and already_edge["source_id"] is not None:
already_source_ids.extend(
split_string_by_multi_markers(
already_edge["source_id"], [GRAPH_FIELD_SEP]
)
already_description.append(already_edge["description"])
already_keywords.extend(
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
)
# Get description with empty string default if missing or None
if (
"description" in already_edge
and already_edge["description"] is not None
):
already_description.append(already_edge["description"])
# Get keywords with empty string default if missing or None
if "keywords" in already_edge and already_edge["keywords"] is not None:
already_keywords.extend(
split_string_by_multi_markers(
already_edge["keywords"], [GRAPH_FIELD_SEP]
)
)
# Process edges_data with None checks
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
description = GRAPH_FIELD_SEP.join(
sorted(set([dp["description"] for dp in edges_data] + already_description))
sorted(
set(
[dp["description"] for dp in edges_data if dp.get("description")]
+ already_description
)
)
)
keywords = GRAPH_FIELD_SEP.join(
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
sorted(
set(
[dp["keywords"] for dp in edges_data if dp.get("keywords")]
+ already_keywords
)
)
)
source_id = GRAPH_FIELD_SEP.join(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
set(
[dp["source_id"] for dp in edges_data if dp.get("source_id")]
+ already_source_ids
)
)
for need_insert_id in [src_id, tgt_id]:
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(