Refactor Neo4j APOC fallback retrieval implementation

This commit is contained in:
yangdx
2025-03-08 04:28:54 +08:00
parent c07b592e1b
commit fcb04e47e5

View File

@@ -3,7 +3,7 @@ import inspect
import os import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Dict, final from typing import Any, final, Optional
import numpy as np import numpy as np
import configparser import configparser
@@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
) )
return degrees return degrees
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
@@ -321,60 +320,59 @@ class Neo4JStorage(BaseGraphStorage):
""" """
result = await session.run(query) result = await session.run(query)
try: records = await result.fetch(2) # Get up to 2 records to check for duplicates
records = await result.fetch(2) # Get up to 2 records to check for duplicates await result.consume() # Ensure result is fully consumed before processing records
if len(records) > 1:
logger.warning( if len(records) > 1:
f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge." logger.warning(
) f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
if records:
try:
result = dict(records[0]["edge_properties"])
logger.debug(f"Result: {result}")
# Ensure required keys exist with defaults
required_keys = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in result:
result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
)
return result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
) )
# Return default edge properties when no edge found if records:
return { try:
"weight": 0.0, edge_result = dict(records[0]["edge_properties"])
"description": None, logger.debug(f"Result: {edge_result}")
"keywords": None, # Ensure required keys exist with defaults
"source_id": None, required_keys = {
} "weight": 0.0,
finally: "source_id": None,
await result.consume() # Ensure result is fully consumed "description": None,
"keywords": None,
}
for key, default_value in required_keys.items():
if key not in edge_result:
edge_result[key] = default_value
logger.warning(
f"Edge between {entity_name_label_source} and {entity_name_label_target} "
f"missing {key}, using default: {default_value}"
)
logger.debug(
f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
)
return edge_result
except (KeyError, TypeError, ValueError) as e:
logger.error(
f"Error processing edge properties between {entity_name_label_source} "
f"and {entity_name_label_target}: {str(e)}"
)
# Return default edge properties on error
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
logger.debug(
f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
)
# Return default edge properties when no edge found
return {
"weight": 0.0,
"description": None,
"keywords": None,
"source_id": None,
}
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -685,30 +683,36 @@ class Neo4JStorage(BaseGraphStorage):
await result_set.consume() # Ensure result set is consumed await result_set.consume() # Ensure result set is consumed
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
logger.warning( logger.warning(f"APOC plugin error: {str(e)}")
f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation" if label != "*":
)
if inclusive:
logger.warning( logger.warning(
"Inclusive search mode is not supported in recursive query, using exact matching" "Neo4j: falling back to basic Cypher recursive search..."
) )
return await self._robust_fallback(label, max_depth, min_degree) if inclusive:
logger.warning(
"Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
)
return await self._robust_fallback(label, max_depth, min_degree)
return result return result
async def _robust_fallback( async def _robust_fallback(
self, label: str, max_depth: int, min_degree: int = 0 self, label: str, max_depth: int, min_degree: int = 0
) -> Dict[str, List[Dict]]: ) -> KnowledgeGraph:
""" """
Fallback implementation when APOC plugin is not available or incompatible. Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures. only basic Cypher queries and recursive traversal instead of APOC procedures.
""" """
result = {"nodes": [], "edges": []} result = KnowledgeGraph()
visited_nodes = set() visited_nodes = set()
visited_edges = set() visited_edges = set()
async def traverse(current_label: str, current_depth: int): async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits # Check traversal limits
if current_depth > max_depth: if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}") logger.debug(f"Reached max depth: {max_depth}")
@@ -717,62 +721,101 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
return return
# Get current node details # Check if node already visited
node = await self.get_node(current_label) if node.id in visited_nodes:
if not node:
return return
node_id = f"{current_label}" # Get all edges and target nodes
if node_id in visited_nodes:
return
visited_nodes.add(node_id)
# Add node data with label as ID
result["nodes"].append(
{"id": current_label, "labels": current_label, "properties": node}
)
# Get connected nodes that meet the degree requirement
# Note: We don't need to check a's degree since it's the current node
# and was already validated in the previous iteration
query = f"""
MATCH (a:`{current_label}`)-[r]-(b)
WITH r, b,
COUNT((b)--()) AS b_degree
WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
RETURN r, b
"""
async with self._driver.session( async with self._driver.session(
database=self._DATABASE, default_access_mode="READ" database=self._DATABASE, default_access_mode="READ"
) as session: ) as session:
results = await session.run(query, {"min_degree": min_degree}) query = """
async for record in results: MATCH (a)-[r]-(b)
# Handle edges WHERE id(a) = toInteger($node_id)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, {"node_id": node.id})
# Get all records and release database connection
records = await results.fetch()
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
if current_depth > 1 and len(records) < min_degree:
return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"] rel = record["r"]
edge_id = f"{rel.id}_{rel.type}" edge_id = str(record["edge_id"])
if edge_id not in visited_edges: if edge_id not in visited_edges:
b_node = record["b"] b_node = record["b"]
if b_node.labels: # Only process if target node has labels target_id = str(record["target_id"])
target_label = list(b_node.labels)[0]
result["edges"].append(
{
"id": f"{current_label}_{target_label}",
"type": rel.type,
"source": current_label,
"target": target_label,
"properties": dict(rel),
}
)
visited_edges.add(edge_id)
# Continue traversal if b_node.labels: # Only process if target node has labels
await traverse(target_label, current_depth + 1) # Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=target_id,
labels=list(b_node.labels),
properties=dict(b_node),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=edge_id,
type=rel.type,
source=node.id,
target=target_id,
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
else: else:
logger.warning( logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node" f"Skipping edge {edge_id} due to missing labels on target node"
) )
await traverse(label, 0) # Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
# Get the starting node's data
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
MATCH (n:`{label}`)
RETURN id(n) as node_id, n
"""
node_result = await session.run(query)
try:
node_record = await node_result.single()
if not node_record:
return result
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode(
id=str(node_record["node_id"]),
labels=list(node_record["n"].labels),
properties=dict(node_record["n"]),
)
finally:
await node_result.consume() # Ensure results are consumed
# Start traversal with the initial node
await traverse(start_node, None, 0)
return result return result
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]: