Enhance Neo4j graph storage with error handling and label validation
- Add label existence check and validation methods in Neo4j implementation - Improve error handling in get_node, get_edge, and upsert methods - Add default values and logging for missing edge properties - Ensure consistent label processing across graph storage methods
This commit is contained in:
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
print("KG successfully indexed.")
|
print("KG successfully indexed.")
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def _label_exists(self, label: str) -> bool:
|
||||||
entity_name_label = node_id.strip('"')
|
"""Check if a label exists in the Neo4j database."""
|
||||||
|
query = "CALL db.labels() YIELD label RETURN label"
|
||||||
|
try:
|
||||||
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
|
result = await session.run(query)
|
||||||
|
labels = [record["label"] for record in await result.data()]
|
||||||
|
return label in labels
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking label existence: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _ensure_label(self, label: str) -> str:
|
||||||
|
"""Ensure a label exists by validating it."""
|
||||||
|
clean_label = label.strip('"')
|
||||||
|
if not await self._label_exists(clean_label):
|
||||||
|
logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
|
||||||
|
return clean_label
|
||||||
|
|
||||||
|
async def has_node(self, node_id: str) -> bool:
|
||||||
|
entity_name_label = await self._ensure_label(node_id)
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = (
|
query = (
|
||||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||||
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
return single_result["edgeExists"]
|
return single_result["edgeExists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
|
"""Get node by its label identifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: The node label to look up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Node properties if found
|
||||||
|
None: If node not found
|
||||||
|
"""
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
entity_name_label = node_id.strip('"')
|
entity_name_label = await self._ensure_label(node_id)
|
||||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||||
result = await session.run(query)
|
result = await session.run(query)
|
||||||
record = await result.single()
|
record = await result.single()
|
||||||
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
entity_name_label_source = source_node_id.strip('"')
|
"""Find edge between two nodes identified by their labels.
|
||||||
entity_name_label_target = target_node_id.strip('"')
|
|
||||||
"""
|
|
||||||
Find all edges between nodes of two given labels
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_node_label (str): Label of the source nodes
|
source_node_id (str): Label of the source node
|
||||||
target_node_label (str): Label of the target nodes
|
target_node_id (str): Label of the target node
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: List of all relationships/edges found
|
dict: Edge properties if found, with at least {"weight": 0.0}
|
||||||
|
None: If error occurs
|
||||||
"""
|
"""
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
try:
|
||||||
query = f"""
|
entity_name_label_source = source_node_id.strip('"')
|
||||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
entity_name_label_target = target_node_id.strip('"')
|
||||||
RETURN properties(r) as edge_properties
|
|
||||||
LIMIT 1
|
|
||||||
""".format(
|
|
||||||
entity_name_label_source=entity_name_label_source,
|
|
||||||
entity_name_label_target=entity_name_label_target,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await session.run(query)
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
record = await result.single()
|
query = f"""
|
||||||
if record:
|
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||||
result = dict(record["edge_properties"])
|
RETURN properties(r) as edge_properties
|
||||||
logger.debug(
|
LIMIT 1
|
||||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
""".format(
|
||||||
|
entity_name_label_source=entity_name_label_source,
|
||||||
|
entity_name_label_target=entity_name_label_target,
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
else:
|
result = await session.run(query)
|
||||||
return None
|
record = await result.single()
|
||||||
|
if record and "edge_properties" in record:
|
||||||
|
try:
|
||||||
|
result = dict(record["edge_properties"])
|
||||||
|
# Ensure required keys exist with defaults
|
||||||
|
required_keys = {
|
||||||
|
"weight": 0.0,
|
||||||
|
"source_id": None,
|
||||||
|
"target_id": None,
|
||||||
|
}
|
||||||
|
for key, default_value in required_keys.items():
|
||||||
|
if key not in result:
|
||||||
|
result[key] = default_value
|
||||||
|
logger.warning(
|
||||||
|
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
|
||||||
|
f"missing {key}, using default: {default_value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except (KeyError, TypeError, ValueError) as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error processing edge properties between {entity_name_label_source} "
|
||||||
|
f"and {entity_name_label_target}: {str(e)}"
|
||||||
|
)
|
||||||
|
# Return default edge properties on error
|
||||||
|
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
|
||||||
|
)
|
||||||
|
# Return default edge properties when no edge found
|
||||||
|
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
# Return default edge properties on error
|
||||||
|
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||||
node_label = source_node_id.strip('"')
|
node_label = source_node_id.strip('"')
|
||||||
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
node_id: The unique identifier for the node (used as label)
|
node_id: The unique identifier for the node (used as label)
|
||||||
node_data: Dictionary of node properties
|
node_data: Dictionary of node properties
|
||||||
"""
|
"""
|
||||||
label = node_id.strip('"')
|
label = await self._ensure_label(node_id)
|
||||||
properties = node_data
|
properties = node_data
|
||||||
|
|
||||||
async def _do_upsert(tx: AsyncManagedTransaction):
|
async def _do_upsert(tx: AsyncManagedTransaction):
|
||||||
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
neo4jExceptions.ServiceUnavailable,
|
neo4jExceptions.ServiceUnavailable,
|
||||||
neo4jExceptions.TransientError,
|
neo4jExceptions.TransientError,
|
||||||
neo4jExceptions.WriteServiceUnavailable,
|
neo4jExceptions.WriteServiceUnavailable,
|
||||||
|
neo4jExceptions.ClientError,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
target_node_id (str): Label of the target node (used as identifier)
|
target_node_id (str): Label of the target node (used as identifier)
|
||||||
edge_data (dict): Dictionary of properties to set on the edge
|
edge_data (dict): Dictionary of properties to set on the edge
|
||||||
"""
|
"""
|
||||||
source_node_label = source_node_id.strip('"')
|
source_label = await self._ensure_label(source_node_id)
|
||||||
target_node_label = target_node_id.strip('"')
|
target_label = await self._ensure_label(target_node_id)
|
||||||
edge_properties = edge_data
|
edge_properties = edge_data
|
||||||
|
|
||||||
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (source:`{source_node_label}`)
|
MATCH (source:`{source_label}`)
|
||||||
WITH source
|
WITH source
|
||||||
MATCH (target:`{target_node_label}`)
|
MATCH (target:`{target_label}`)
|
||||||
MERGE (source)-[r:DIRECTED]->(target)
|
MERGE (source)-[r:DIRECTED]->(target)
|
||||||
SET r += $properties
|
SET r += $properties
|
||||||
RETURN r
|
RETURN r
|
||||||
"""
|
"""
|
||||||
await tx.run(query, properties=edge_properties)
|
result = await tx.run(query, properties=edge_properties)
|
||||||
|
record = await result.single()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
|
f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@@ -237,25 +237,65 @@ async def _merge_edges_then_upsert(
|
|||||||
|
|
||||||
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
|
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
|
||||||
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
|
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
|
||||||
already_weights.append(already_edge["weight"])
|
# Handle the case where get_edge returns None or missing fields
|
||||||
already_source_ids.extend(
|
if already_edge:
|
||||||
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
|
# Get weight with default 0.0 if missing
|
||||||
)
|
if "weight" in already_edge:
|
||||||
already_description.append(already_edge["description"])
|
already_weights.append(already_edge["weight"])
|
||||||
already_keywords.extend(
|
else:
|
||||||
split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
|
logger.warning(
|
||||||
)
|
f"Edge between {src_id} and {tgt_id} missing weight field"
|
||||||
|
)
|
||||||
|
already_weights.append(0.0)
|
||||||
|
|
||||||
|
# Get source_id with empty string default if missing or None
|
||||||
|
if "source_id" in already_edge and already_edge["source_id"] is not None:
|
||||||
|
already_source_ids.extend(
|
||||||
|
split_string_by_multi_markers(
|
||||||
|
already_edge["source_id"], [GRAPH_FIELD_SEP]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get description with empty string default if missing or None
|
||||||
|
if (
|
||||||
|
"description" in already_edge
|
||||||
|
and already_edge["description"] is not None
|
||||||
|
):
|
||||||
|
already_description.append(already_edge["description"])
|
||||||
|
|
||||||
|
# Get keywords with empty string default if missing or None
|
||||||
|
if "keywords" in already_edge and already_edge["keywords"] is not None:
|
||||||
|
already_keywords.extend(
|
||||||
|
split_string_by_multi_markers(
|
||||||
|
already_edge["keywords"], [GRAPH_FIELD_SEP]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process edges_data with None checks
|
||||||
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
weight = sum([dp["weight"] for dp in edges_data] + already_weights)
|
||||||
description = GRAPH_FIELD_SEP.join(
|
description = GRAPH_FIELD_SEP.join(
|
||||||
sorted(set([dp["description"] for dp in edges_data] + already_description))
|
sorted(
|
||||||
|
set(
|
||||||
|
[dp["description"] for dp in edges_data if dp.get("description")]
|
||||||
|
+ already_description
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
keywords = GRAPH_FIELD_SEP.join(
|
keywords = GRAPH_FIELD_SEP.join(
|
||||||
sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
|
sorted(
|
||||||
|
set(
|
||||||
|
[dp["keywords"] for dp in edges_data if dp.get("keywords")]
|
||||||
|
+ already_keywords
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
source_id = GRAPH_FIELD_SEP.join(
|
source_id = GRAPH_FIELD_SEP.join(
|
||||||
set([dp["source_id"] for dp in edges_data] + already_source_ids)
|
set(
|
||||||
|
[dp["source_id"] for dp in edges_data if dp.get("source_id")]
|
||||||
|
+ already_source_ids
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for need_insert_id in [src_id, tgt_id]:
|
for need_insert_id in [src_id, tgt_id]:
|
||||||
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
||||||
await knowledge_graph_inst.upsert_node(
|
await knowledge_graph_inst.upsert_node(
|
||||||
|
Reference in New Issue
Block a user