Refactor Neo4j APOC fallback retrieval implementation
@@ -3,7 +3,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, List, Dict, final
+from typing import Any, final, Optional
 import numpy as np
 import configparser
 
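Note on the import change: `List` and `Dict` are dropped because `_robust_fallback` (below) now returns a `KnowledgeGraph` instead of `Dict[str, List[Dict]]`, and `Optional` is added for the new `traverse(node, edge, current_depth)` signature, where `edge` is an `Optional[KnowledgeGraphEdge]`. The graph types themselves are imported elsewhere in the module; a minimal sketch of the shape this diff assumes (hypothetical stand-ins, not the project's actual definitions):

# Hypothetical stand-ins for the graph types used in this diff; the real
# definitions live elsewhere in the repository.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class KnowledgeGraphNode:
    id: str  # stringified internal Neo4j node id
    labels: list[str]  # all labels attached to the node
    properties: dict[str, Any] = field(default_factory=dict)


@dataclass
class KnowledgeGraphEdge:
    id: str  # stringified internal Neo4j relationship id
    type: str  # relationship type
    source: str  # id of the source node
    target: str  # id of the target node
    properties: dict[str, Any] = field(default_factory=dict)


@dataclass
class KnowledgeGraph:
    nodes: list[KnowledgeGraphNode] = field(default_factory=list)
    edges: list[KnowledgeGraphEdge] = field(default_factory=list)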
@@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
                 )
             return degrees
 
-
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
@@ -321,60 +320,59 @@ class Neo4JStorage(BaseGraphStorage):
                 """
 
                 result = await session.run(query)
-                try:
-                    records = await result.fetch(2)  # Get up to 2 records to check for duplicates
-                    if len(records) > 1:
-                        logger.warning(
-                            f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
-                        )
-                    if records:
-                        try:
-                            result = dict(records[0]["edge_properties"])
-                            logger.debug(f"Result: {result}")
-                            # Ensure required keys exist with defaults
-                            required_keys = {
-                                "weight": 0.0,
-                                "source_id": None,
-                                "description": None,
-                                "keywords": None,
-                            }
-                            for key, default_value in required_keys.items():
-                                if key not in result:
-                                    result[key] = default_value
-                                    logger.warning(
-                                        f"Edge between {entity_name_label_source} and {entity_name_label_target} "
-                                        f"missing {key}, using default: {default_value}"
-                                    )
-
-                            logger.debug(
-                                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
-                            )
-                            return result
-                        except (KeyError, TypeError, ValueError) as e:
-                            logger.error(
-                                f"Error processing edge properties between {entity_name_label_source} "
-                                f"and {entity_name_label_target}: {str(e)}"
-                            )
-                            # Return default edge properties on error
-                            return {
-                                "weight": 0.0,
-                                "description": None,
-                                "keywords": None,
-                                "source_id": None,
-                            }
-
-                    logger.debug(
-                        f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
-                    )
-                    # Return default edge properties when no edge found
-                    return {
-                        "weight": 0.0,
-                        "description": None,
-                        "keywords": None,
-                        "source_id": None,
-                    }
-                finally:
-                    await result.consume()  # Ensure result is fully consumed
+                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+                await result.consume()  # Ensure result is fully consumed before processing records
+
+                if len(records) > 1:
+                    logger.warning(
+                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                    )
+                if records:
+                    try:
+                        edge_result = dict(records[0]["edge_properties"])
+                        logger.debug(f"Result: {edge_result}")
+                        # Ensure required keys exist with defaults
+                        required_keys = {
+                            "weight": 0.0,
+                            "source_id": None,
+                            "description": None,
+                            "keywords": None,
+                        }
+                        for key, default_value in required_keys.items():
+                            if key not in edge_result:
+                                edge_result[key] = default_value
+                                logger.warning(
+                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+                                    f"missing {key}, using default: {default_value}"
+                                )
+
+                        logger.debug(
+                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
+                        )
+                        return edge_result
+                    except (KeyError, TypeError, ValueError) as e:
+                        logger.error(
+                            f"Error processing edge properties between {entity_name_label_source} "
+                            f"and {entity_name_label_target}: {str(e)}"
+                        )
+                        # Return default edge properties on error
+                        return {
+                            "weight": 0.0,
+                            "description": None,
+                            "keywords": None,
+                            "source_id": None,
+                        }
+
+                logger.debug(
+                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                )
+                # Return default edge properties when no edge found
+                return {
+                    "weight": 0.0,
+                    "description": None,
+                    "keywords": None,
+                    "source_id": None,
+                }
 
         except Exception as e:
             logger.error(
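Note on the get_edge change: the old code consumed the Neo4j result in a `finally:` block after processing, and reused the name `result` for both the driver's result object and the edge dictionary, shadowing the former. The new code fetches the records, consumes the result immediately (releasing the server-side cursor before any processing), and renames the dictionary to `edge_result`. A minimal standalone sketch of the fetch-then-consume pattern with the async Neo4j driver (URI, credentials, and query are placeholders):

# Sketch of the fetch-then-consume pattern; connection details are placeholders.
import asyncio

from neo4j import AsyncGraphDatabase


async def main() -> None:
    driver = AsyncGraphDatabase.driver(
        "neo4j://localhost:7687", auth=("neo4j", "password")
    )
    async with driver.session(default_access_mode="READ") as session:
        result = await session.run(
            "MATCH ()-[r]-() RETURN properties(r) AS edge_properties LIMIT 2"
        )
        records = await result.fetch(2)  # pull the rows we need into memory
        await result.consume()  # release the server-side result before processing
        if len(records) > 1:
            print("multiple edges found, using the first one")
        if records:
            print(dict(records[0]["edge_properties"]))
    await driver.close()


asyncio.run(main())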
@@ -685,30 +683,36 @@ class Neo4JStorage(BaseGraphStorage):
                 await result_set.consume()  # Ensure result set is consumed
 
         except neo4jExceptions.ClientError as e:
-            logger.warning(
-                f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
-            )
-            if inclusive:
-                logger.warning(
-                    "Inclusive search mode is not supported in recursive query, using exact matching"
-                )
-            return await self._robust_fallback(label, max_depth, min_degree)
+            logger.warning(f"APOC plugin error: {str(e)}")
+            if label != "*":
+                logger.warning(
+                    "Neo4j: falling back to basic Cypher recursive search..."
+                )
+                if inclusive:
+                    logger.warning(
+                        "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
+                    )
+                return await self._robust_fallback(label, max_depth, min_degree)
 
         return result
 
     async def _robust_fallback(
         self, label: str, max_depth: int, min_degree: int = 0
-    ) -> Dict[str, List[Dict]]:
+    ) -> KnowledgeGraph:
         """
         Fallback implementation when APOC plugin is not available or incompatible.
         This method implements the same functionality as get_knowledge_graph but uses
         only basic Cypher queries and recursive traversal instead of APOC procedures.
         """
-        result = {"nodes": [], "edges": []}
+        result = KnowledgeGraph()
         visited_nodes = set()
         visited_edges = set()
 
-        async def traverse(current_label: str, current_depth: int):
+        async def traverse(
+            node: KnowledgeGraphNode,
+            edge: Optional[KnowledgeGraphEdge],
+            current_depth: int,
+        ):
             # Check traversal limits
             if current_depth > max_depth:
                 logger.debug(f"Reached max depth: {max_depth}")
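Note on the fallback gate: the old code fell back to `_robust_fallback` on any APOC `ClientError`; the new code only does so for a concrete label (`label != "*"`), since the recursive fallback needs a specific start node. The return type also changes from `Dict[str, List[Dict]]` to the same `KnowledgeGraph` the APOC path produces, so callers handle one shape. A hedged usage sketch (`storage` stands in for a configured `Neo4JStorage` instance; setup omitted):

# Usage sketch: the fallback now returns a KnowledgeGraph, uniform with the
# APOC path. `storage` is assumed to be a configured Neo4JStorage instance.
async def print_subgraph(storage, label: str) -> None:
    kg = await storage._robust_fallback(label, max_depth=2, min_degree=0)
    print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")
    for edge in kg.edges:
        print(f"{edge.source} -[{edge.type}]-> {edge.target}")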
@@ -717,62 +721,101 @@ class Neo4JStorage(BaseGraphStorage):
                 logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                 return
 
-            # Get current node details
-            node = await self.get_node(current_label)
-            if not node:
-                return
-
-            node_id = f"{current_label}"
-            if node_id in visited_nodes:
-                return
-            visited_nodes.add(node_id)
-
-            # Add node data with label as ID
-            result["nodes"].append(
-                {"id": current_label, "labels": current_label, "properties": node}
-            )
-
-            # Get connected nodes that meet the degree requirement
-            # Note: We don't need to check a's degree since it's the current node
-            # and was already validated in the previous iteration
-            query = f"""
-                MATCH (a:`{current_label}`)-[r]-(b)
-                WITH r, b,
-                     COUNT((b)--()) AS b_degree
-                WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
-                RETURN r, b
-            """
+            # Check if node already visited
+            if node.id in visited_nodes:
+                return
+
+            # Get all edges and target nodes
             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
-                results = await session.run(query, {"min_degree": min_degree})
-                async for record in results:
-                    # Handle edges
+                query = """
+                MATCH (a)-[r]-(b)
+                WHERE id(a) = toInteger($node_id)
+                WITH r, b, id(r) as edge_id, id(b) as target_id
+                RETURN r, b, edge_id, target_id
+                """
+                results = await session.run(query, {"node_id": node.id})
+
+                # Get all records and release database connection
+                records = await results.fetch()
+                await results.consume()  # Ensure results are consumed
+
+                # Nodes not connected to start node need to check degree
+                if current_depth > 1 and len(records) < min_degree:
+                    return
+
+                # Add current node to result
+                result.nodes.append(node)
+                visited_nodes.add(node.id)
+
+                # Add edge to result if it exists and not already added
+                if edge and edge.id not in visited_edges:
+                    result.edges.append(edge)
+                    visited_edges.add(edge.id)
+
+                # Prepare nodes and edges for recursive processing
+                nodes_to_process = []
+                for record in records:
                     rel = record["r"]
-                    edge_id = f"{rel.id}_{rel.type}"
+                    edge_id = str(record["edge_id"])
                     if edge_id not in visited_edges:
                         b_node = record["b"]
-                        if b_node.labels:  # Only process if target node has labels
-                            target_label = list(b_node.labels)[0]
-                            result["edges"].append(
-                                {
-                                    "id": f"{current_label}_{target_label}",
-                                    "type": rel.type,
-                                    "source": current_label,
-                                    "target": target_label,
-                                    "properties": dict(rel),
-                                }
-                            )
-                            visited_edges.add(edge_id)
-
-                            # Continue traversal
-                            await traverse(target_label, current_depth + 1)
+                        target_id = str(record["target_id"])
+
+                        if b_node.labels:  # Only process if target node has labels
+                            # Create KnowledgeGraphNode for target
+                            target_node = KnowledgeGraphNode(
+                                id=target_id,
+                                labels=list(b_node.labels),
+                                properties=dict(b_node),
+                            )
+
+                            # Create KnowledgeGraphEdge
+                            target_edge = KnowledgeGraphEdge(
+                                id=edge_id,
+                                type=rel.type,
+                                source=node.id,
+                                target=target_id,
+                                properties=dict(rel),
+                            )
+
+                            nodes_to_process.append((target_node, target_edge))
                         else:
                             logger.warning(
                                 f"Skipping edge {edge_id} due to missing labels on target node"
                             )
 
-        await traverse(label, 0)
+            # Process nodes after releasing database connection
+            for target_node, target_edge in nodes_to_process:
+                await traverse(target_node, target_edge, current_depth + 1)
+
+        # Get the starting node's data
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = f"""
+                MATCH (n:`{label}`)
+                RETURN id(n) as node_id, n
+            """
+            node_result = await session.run(query)
+            try:
+                node_record = await node_result.single()
+                if not node_record:
+                    return result
+
+                # Create initial KnowledgeGraphNode
+                start_node = KnowledgeGraphNode(
+                    id=str(node_record["node_id"]),
+                    labels=list(node_record["n"].labels),
+                    properties=dict(node_record["n"]),
+                )
+            finally:
+                await node_result.consume()  # Ensure results are consumed
+
+        # Start traversal with the initial node
+        await traverse(start_node, None, 0)
+
         return result
 
     async def get_all_labels(self) -> list[str]:
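Note on the rewritten traversal: instead of re-resolving nodes by label (`node_id = f"{current_label}"`), the fallback now seeds itself with the start node's internal id, expands one node per query, drains and consumes each result before recursing (so recursion never holds a session open), and deduplicates by node and edge ids. A condensed, self-contained illustration of that dedup-and-recurse shape on an in-memory adjacency list (illustration only; the real code runs a Cypher query per step):

# Toy illustration of the fallback traversal: depth-first walk with
# visited-id dedup, recursing only after the neighbor list is collected
# (the real code releases its database session at that point).
from typing import Optional

# toy adjacency data: node id -> list of (edge_id, neighbor_id)
ADJACENCY: dict[str, list[tuple[str, str]]] = {
    "1": [("e1", "2"), ("e2", "3")],
    "2": [("e1", "1"), ("e3", "3")],
    "3": [("e2", "1"), ("e3", "2")],
}

nodes: list[str] = []
edges: list[str] = []
visited_nodes: set[str] = set()
visited_edges: set[str] = set()


def traverse(node_id: str, edge_id: Optional[str], depth: int, max_depth: int) -> None:
    if depth > max_depth or node_id in visited_nodes:
        return
    nodes.append(node_id)
    visited_nodes.add(node_id)
    if edge_id is not None and edge_id not in visited_edges:
        edges.append(edge_id)
        visited_edges.add(edge_id)
    # collect the neighbor list first, then recurse
    to_process = [
        (e, n) for e, n in ADJACENCY.get(node_id, []) if e not in visited_edges
    ]
    for next_edge, next_node in to_process:
        traverse(next_node, next_edge, depth + 1, max_depth)


traverse("1", None, 0, max_depth=2)
print(nodes, edges)  # ['1', '2', '3'] ['e1', 'e3']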