Merge pull request #780 from spo0nman/fix-bug-778
Enhance Neo4j graph storage with error handling and label validation
This commit is contained in:
@@ -143,9 +143,27 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
async def _label_exists(self, label: str) -> bool:
|
||||
"""Check if a label exists in the Neo4j database."""
|
||||
query = "CALL db.labels() YIELD label RETURN label"
|
||||
try:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
result = await session.run(query)
|
||||
labels = [record["label"] for record in await result.data()]
|
||||
return label in labels
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking label existence: {e}")
|
||||
return False
|
||||
|
||||
async def _ensure_label(self, label: str) -> str:
|
||||
"""Ensure a label exists by validating it."""
|
||||
clean_label = label.strip('"')
|
||||
if not await self._label_exists(clean_label):
|
||||
logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
|
||||
return clean_label
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = await self._ensure_label(node_id)
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||
@@ -174,8 +192,17 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return single_result["edgeExists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""Get node by its label identifier.
|
||||
|
||||
Args:
|
||||
node_id: The node label to look up
|
||||
|
||||
Returns:
|
||||
dict: Node properties if found
|
||||
None: If node not found
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
entity_name_label = node_id.strip('"')
|
||||
entity_name_label = await self._ensure_label(node_id)
|
||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||
result = await session.run(query)
|
||||
record = await result.single()
|
||||
@@ -226,38 +253,73 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
"""
|
||||
Find all edges between nodes of two given labels
|
||||
"""Find edge between two nodes identified by their labels.
|
||||
|
||||
Args:
|
||||
source_node_label (str): Label of the source nodes
|
||||
target_node_label (str): Label of the target nodes
|
||||
source_node_id (str): Label of the source node
|
||||
target_node_id (str): Label of the target node
|
||||
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
dict: Edge properties if found, with at least {"weight": 0.0}
|
||||
None: If error occurs
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
""".format(
|
||||
entity_name_label_source=entity_name_label_source,
|
||||
entity_name_label_target=entity_name_label_target,
|
||||
)
|
||||
try:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
result = await session.run(query)
|
||||
record = await result.single()
|
||||
if record:
|
||||
result = dict(record["edge_properties"])
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
""".format(
|
||||
entity_name_label_source=entity_name_label_source,
|
||||
entity_name_label_target=entity_name_label_target,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
result = await session.run(query)
|
||||
record = await result.single()
|
||||
if record and "edge_properties" in record:
|
||||
try:
|
||||
result = dict(record["edge_properties"])
|
||||
# Ensure required keys exist with defaults
|
||||
required_keys = {
|
||||
"weight": 0.0,
|
||||
"source_id": None,
|
||||
"target_id": None,
|
||||
}
|
||||
for key, default_value in required_keys.items():
|
||||
if key not in result:
|
||||
result[key] = default_value
|
||||
logger.warning(
|
||||
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
|
||||
f"missing {key}, using default: {default_value}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
||||
)
|
||||
return result
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.error(
|
||||
f"Error processing edge properties between {entity_name_label_source} "
|
||||
f"and {entity_name_label_target}: {str(e)}"
|
||||
)
|
||||
# Return default edge properties on error
|
||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
|
||||
)
|
||||
# Return default edge properties when no edge found
|
||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
|
||||
)
|
||||
# Return default edge properties on error
|
||||
return {"weight": 0.0, "source_id": None, "target_id": None}
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
node_label = source_node_id.strip('"')
|
||||
@@ -310,7 +372,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
node_id: The unique identifier for the node (used as label)
|
||||
node_data: Dictionary of node properties
|
||||
"""
|
||||
label = node_id.strip('"')
|
||||
label = await self._ensure_label(node_id)
|
||||
properties = node_data
|
||||
|
||||
async def _do_upsert(tx: AsyncManagedTransaction):
|
||||
@@ -338,6 +400,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
neo4jExceptions.ClientError,
|
||||
)
|
||||
),
|
||||
)
|
||||
@@ -352,22 +415,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
target_node_id (str): Label of the target node (used as identifier)
|
||||
edge_data (dict): Dictionary of properties to set on the edge
|
||||
"""
|
||||
source_node_label = source_node_id.strip('"')
|
||||
target_node_label = target_node_id.strip('"')
|
||||
source_label = await self._ensure_label(source_node_id)
|
||||
target_label = await self._ensure_label(target_node_id)
|
||||
edge_properties = edge_data
|
||||
|
||||
async def _do_upsert_edge(tx: AsyncManagedTransaction):
|
||||
query = f"""
|
||||
MATCH (source:`{source_node_label}`)
|
||||
MATCH (source:`{source_label}`)
|
||||
WITH source
|
||||
MATCH (target:`{target_node_label}`)
|
||||
MATCH (target:`{target_label}`)
|
||||
MERGE (source)-[r:DIRECTED]->(target)
|
||||
SET r += $properties
|
||||
RETURN r
|
||||
"""
|
||||
await tx.run(query, properties=edge_properties)
|
||||
result = await tx.run(query, properties=edge_properties)
|
||||
record = await result.single()
|
||||
logger.debug(
|
||||
f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
|
||||
f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
Reference in New Issue
Block a user