Refactor Neo4J APOC fall back retrival implementaion

This commit is contained in:
yangdx
2025-03-08 04:28:54 +08:00
parent c07b592e1b
commit fcb04e47e5

View File

@@ -3,7 +3,7 @@ import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, List, Dict, final
from typing import Any, final, Optional
import numpy as np
import configparser
@@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
@@ -321,16 +320,17 @@ class Neo4JStorage(BaseGraphStorage):
"""
result = await session.run(query)
try:
records = await result.fetch(2) # Get up to 2 records to check for duplicates
await result.consume() # Ensure result is fully consumed before processing records
if len(records) > 1:
logger.warning(
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
)
if records:
try:
result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {result}")
edge_result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {edge_result}")
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
@@ -339,17 +339,17 @@ class Neo4JStorage(BaseGraphStorage):
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
if key not in edge_result:
edge_result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
)
return result
return edge_result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
@@ -373,8 +373,6 @@ class Neo4JStorage(BaseGraphStorage):
"keywords": None,
"source_id": None,
}
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(
@@ -685,12 +683,14 @@ class Neo4JStorage(BaseGraphStorage):
await result_set.consume() # Ensure result set is consumed
except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}")
if label != "*":
logger.warning(
f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
"Neo4j: falling back to basic Cypher recursive search..."
)
if inclusive:
logger.warning(
"Inclusive search mode is not supported in recursive query, using exact matching"
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
)
return await self._robust_fallback(label, max_depth, min_degree)
@@ -698,17 +698,21 @@ class Neo4JStorage(BaseGraphStorage):
async def _robust_fallback(
self, label: str, max_depth: int, min_degree: int = 0
) -> Dict[str, List[Dict]]:
) -> KnowledgeGraph:
"""
Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures.
"""
result = {"nodes": [], "edges": []}
result = KnowledgeGraph()
visited_nodes = set()
visited_edges = set()
async def traverse(current_label: str, current_depth: int):
async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits
if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}")
@@ -717,62 +721,101 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
return
# Get current node details
node = await self.get_node(current_label)
if not node:
# Check if node already visited
if node.id in visited_nodes:
return
node_id = f"{current_label}"
if node_id in visited_nodes:
return
visited_nodes.add(node_id)
# Add node data with label as ID
result["nodes"].append(
{"id": current_label, "labels": current_label, "properties": node}
)
# Get connected nodes that meet the degree requirement
# Note: We don't need to check a's degree since it's the current node
# and was already validated in the previous iteration
query = f"""
MATCH (a:`{current_label}`)-[r]-(b)
WITH r, b,
COUNT((b)--()) AS b_degree
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
RETURN r, b
"""
# Get all edges and target nodes
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
results = await session.run(query, {"min_degree": min_degree})
async for record in results:
# Handle edges
query = """
MATCH (a)-[r]-(b)
WHERE id(a) = toInteger($node_id)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, {"node_id": node.id})
# Get all records and release database connection
records = await results.fetch()
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
if current_depth > 1 and len(records) < min_degree:
return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"]
edge_id = f"{rel.id}_{rel.type}"
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
if b_node.labels: # Only process if target node has labels
target_label = list(b_node.labels)[0]
result["edges"].append(
{
"id": f"{current_label}_{target_label}",
"type": rel.type,
"source": current_label,
"target": target_label,
"properties": dict(rel),
}
)
visited_edges.add(edge_id)
target_id = str(record["target_id"])
# Continue traversal
await traverse(target_label, current_depth + 1)
if b_node.labels: # Only process if target node has labels
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=target_id,
labels=list(b_node.labels),
properties=dict(b_node),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=edge_id,
type=rel.type,
source=node.id,
target=target_id,
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
else:
logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
await traverse(label, 0)
# Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
# Get the starting node's data
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
MATCH (n:`{label}`)
RETURN id(n) as node_id, n
"""
node_result = await session.run(query)
try:
node_record = await node_result.single()
if not node_record:
return result
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode(
id=str(node_record["node_id"]),
labels=list(node_record["n"].labels),
properties=dict(node_record["n"]),
)
finally:
await node_result.consume() # Ensure results are consumed
# Start traversal with the initial node
await traverse(start_node, None, 0)
return result
async def get_all_labels(self) -> list[str]: