Refactor Neo4J APOC fall back retrival implementaion
This commit is contained in:
@@ -3,7 +3,7 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Dict, final
|
||||
from typing import Any, final, Optional
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
@@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
return degrees
|
||||
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
@@ -321,16 +320,17 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"""
|
||||
|
||||
result = await session.run(query)
|
||||
try:
|
||||
records = await result.fetch(2) # Get up to 2 records to check for duplicates
|
||||
await result.consume() # Ensure result is fully consumed before processing records
|
||||
|
||||
if len(records) > 1:
|
||||
logger.warning(
|
||||
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
|
||||
)
|
||||
if records:
|
||||
try:
|
||||
result = dict(records[0]["edge_properties"])
|
||||
logger.debug(f"Result: {result}")
|
||||
edge_result = dict(records[0]["edge_properties"])
|
||||
logger.debug(f"Result: {edge_result}")
|
||||
# Ensure required keys exist with defaults
|
||||
required_keys = {
|
||||
"weight": 0.0,
|
||||
@@ -339,17 +339,17 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"keywords": None,
|
||||
}
|
||||
for key, default_value in required_keys.items():
|
||||
if key not in result:
|
||||
result[key] = default_value
|
||||
if key not in edge_result:
|
||||
edge_result[key] = default_value
|
||||
logger.warning(
|
||||
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
|
||||
f"missing {key}, using default: {default_value}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
|
||||
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
|
||||
)
|
||||
return result
|
||||
return edge_result
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
logger.error(
|
||||
f"Error processing edge properties between {entity_name_label_source} "
|
||||
@@ -373,8 +373,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"keywords": None,
|
||||
"source_id": None,
|
||||
}
|
||||
finally:
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -685,12 +683,14 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
await result_set.consume() # Ensure result set is consumed
|
||||
|
||||
except neo4jExceptions.ClientError as e:
|
||||
logger.warning(f"APOC plugin error: {str(e)}")
|
||||
if label != "*":
|
||||
logger.warning(
|
||||
f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
|
||||
"Neo4j: falling back to basic Cypher recursive search..."
|
||||
)
|
||||
if inclusive:
|
||||
logger.warning(
|
||||
"Inclusive search mode is not supported in recursive query, using exact matching"
|
||||
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
|
||||
)
|
||||
return await self._robust_fallback(label, max_depth, min_degree)
|
||||
|
||||
@@ -698,17 +698,21 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def _robust_fallback(
|
||||
self, label: str, max_depth: int, min_degree: int = 0
|
||||
) -> Dict[str, List[Dict]]:
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Fallback implementation when APOC plugin is not available or incompatible.
|
||||
This method implements the same functionality as get_knowledge_graph but uses
|
||||
only basic Cypher queries and recursive traversal instead of APOC procedures.
|
||||
"""
|
||||
result = {"nodes": [], "edges": []}
|
||||
result = KnowledgeGraph()
|
||||
visited_nodes = set()
|
||||
visited_edges = set()
|
||||
|
||||
async def traverse(current_label: str, current_depth: int):
|
||||
async def traverse(
|
||||
node: KnowledgeGraphNode,
|
||||
edge: Optional[KnowledgeGraphEdge],
|
||||
current_depth: int,
|
||||
):
|
||||
# Check traversal limits
|
||||
if current_depth > max_depth:
|
||||
logger.debug(f"Reached max depth: {max_depth}")
|
||||
@@ -717,62 +721,101 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
|
||||
return
|
||||
|
||||
# Get current node details
|
||||
node = await self.get_node(current_label)
|
||||
if not node:
|
||||
# Check if node already visited
|
||||
if node.id in visited_nodes:
|
||||
return
|
||||
|
||||
node_id = f"{current_label}"
|
||||
if node_id in visited_nodes:
|
||||
return
|
||||
visited_nodes.add(node_id)
|
||||
|
||||
# Add node data with label as ID
|
||||
result["nodes"].append(
|
||||
{"id": current_label, "labels": current_label, "properties": node}
|
||||
)
|
||||
|
||||
# Get connected nodes that meet the degree requirement
|
||||
# Note: We don't need to check a's degree since it's the current node
|
||||
# and was already validated in the previous iteration
|
||||
query = f"""
|
||||
MATCH (a:`{current_label}`)-[r]-(b)
|
||||
WITH r, b,
|
||||
COUNT((b)--()) AS b_degree
|
||||
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
|
||||
RETURN r, b
|
||||
"""
|
||||
# Get all edges and target nodes
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
results = await session.run(query, {"min_degree": min_degree})
|
||||
async for record in results:
|
||||
# Handle edges
|
||||
query = """
|
||||
MATCH (a)-[r]-(b)
|
||||
WHERE id(a) = toInteger($node_id)
|
||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
||||
RETURN r, b, edge_id, target_id
|
||||
"""
|
||||
results = await session.run(query, {"node_id": node.id})
|
||||
|
||||
# Get all records and release database connection
|
||||
records = await results.fetch()
|
||||
await results.consume() # Ensure results are consumed
|
||||
|
||||
# Nodes not connected to start node need to check degree
|
||||
if current_depth > 1 and len(records) < min_degree:
|
||||
return
|
||||
|
||||
# Add current node to result
|
||||
result.nodes.append(node)
|
||||
visited_nodes.add(node.id)
|
||||
|
||||
# Add edge to result if it exists and not already added
|
||||
if edge and edge.id not in visited_edges:
|
||||
result.edges.append(edge)
|
||||
visited_edges.add(edge.id)
|
||||
|
||||
# Prepare nodes and edges for recursive processing
|
||||
nodes_to_process = []
|
||||
for record in records:
|
||||
rel = record["r"]
|
||||
edge_id = f"{rel.id}_{rel.type}"
|
||||
edge_id = str(record["edge_id"])
|
||||
if edge_id not in visited_edges:
|
||||
b_node = record["b"]
|
||||
if b_node.labels: # Only process if target node has labels
|
||||
target_label = list(b_node.labels)[0]
|
||||
result["edges"].append(
|
||||
{
|
||||
"id": f"{current_label}_{target_label}",
|
||||
"type": rel.type,
|
||||
"source": current_label,
|
||||
"target": target_label,
|
||||
"properties": dict(rel),
|
||||
}
|
||||
)
|
||||
visited_edges.add(edge_id)
|
||||
target_id = str(record["target_id"])
|
||||
|
||||
# Continue traversal
|
||||
await traverse(target_label, current_depth + 1)
|
||||
if b_node.labels: # Only process if target node has labels
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=target_id,
|
||||
labels=list(b_node.labels),
|
||||
properties=dict(b_node),
|
||||
)
|
||||
|
||||
# Create KnowledgeGraphEdge
|
||||
target_edge = KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type=rel.type,
|
||||
source=node.id,
|
||||
target=target_id,
|
||||
properties=dict(rel),
|
||||
)
|
||||
|
||||
nodes_to_process.append((target_node, target_edge))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Skipping edge {edge_id} due to missing labels on target node"
|
||||
)
|
||||
|
||||
await traverse(label, 0)
|
||||
# Process nodes after releasing database connection
|
||||
for target_node, target_edge in nodes_to_process:
|
||||
await traverse(target_node, target_edge, current_depth + 1)
|
||||
|
||||
# Get the starting node's data
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = f"""
|
||||
MATCH (n:`{label}`)
|
||||
RETURN id(n) as node_id, n
|
||||
"""
|
||||
node_result = await session.run(query)
|
||||
try:
|
||||
node_record = await node_result.single()
|
||||
if not node_record:
|
||||
return result
|
||||
|
||||
# Create initial KnowledgeGraphNode
|
||||
start_node = KnowledgeGraphNode(
|
||||
id=str(node_record["node_id"]),
|
||||
labels=list(node_record["n"].labels),
|
||||
properties=dict(node_record["n"]),
|
||||
)
|
||||
finally:
|
||||
await node_result.consume() # Ensure results are consumed
|
||||
|
||||
# Start traversal with the initial node
|
||||
await traverse(start_node, None, 0)
|
||||
|
||||
return result
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
|
Reference in New Issue
Block a user