Fix APOC fallback error

yangdx
2025-04-04 02:05:29 +08:00
parent 399b2f14f6
commit b003537429


@@ -2,7 +2,7 @@ import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, final, Optional
from typing import Any, final
import numpy as np
import configparser
@@ -713,6 +713,7 @@ class Neo4JStorage(BaseGraphStorage):
await result_set.consume()
else:
# return await self._robust_fallback(node_label, max_depth, max_nodes)
# First try without limit to check if we need to truncate
full_query = """
MATCH (start)
@@ -855,98 +856,14 @@ class Neo4JStorage(BaseGraphStorage):
"""
Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures.
only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
"""
from collections import deque
result = KnowledgeGraph()
visited_nodes = set()
visited_edges = set()
async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits
if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}")
return
if len(visited_nodes) >= max_nodes:
# Set truncated flag when we hit the max_nodes limit
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
return
# Check if node already visited
if node.id in visited_nodes:
return
# Get all edges and target nodes
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=node.id)
# Get all records and release database connection
records = await results.fetch(
1000
) # Max neighbour nodes we can handle
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
# if current_depth > 1 and len(records) < min_degree:
# return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=list(f"{target_id}"),
properties=dict(b_node.properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{node.id}",
target=f"{target_id}",
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
else:
logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
# Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
visited_edge_pairs = set() # Track processed edge pairs as (sorted source_id, target_id)
# Get the starting node's data
async with self._driver.session(
@@ -965,15 +882,129 @@ class Neo4JStorage(BaseGraphStorage):
# Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode(
id=f"{node_record['n'].get('entity_id')}",
labels=list(f"{node_record['n'].get('entity_id')}"),
properties=dict(node_record["n"].properties),
labels=[node_record["n"].get("entity_id")],
properties=dict(node_record["n"]._properties),
)
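# Note on the labels fix above: list() over a string splits it into
# characters (list("HQ") == ["H", "Q"]), so list(f"{entity_id}") produced
# one label per character; [entity_id] keeps the id as a single label.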
finally:
await node_result.consume() # Ensure results are consumed
# Start traversal with the initial node
await traverse(start_node, None, 0)
# Initialize queue for BFS with (node, edge, depth) tuples
# edge is None for the starting node
queue = deque([(start_node, None, 0)])
# True BFS implementation using a queue
while queue and len(visited_nodes) < max_nodes:
# Dequeue the next node to process
current_node, current_edge, current_depth = queue.popleft()
# Skip if already visited or exceeds max depth
if current_node.id in visited_nodes:
continue
if current_depth > max_depth:
logger.debug(
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
)
continue
# Add current node to result
result.nodes.append(current_node)
visited_nodes.add(current_node.id)
# Add edge to result if it exists and not already added
if current_edge and current_edge.id not in visited_edges:
result.edges.append(current_edge)
visited_edges.add(current_edge.id)
# Stop if we've reached the node limit
if len(visited_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
break
# Get all edges and target nodes for the current node (even at max_depth)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=current_node.id)
# Get all records and release database connection
records = await results.fetch(1000) # Max neighbor nodes we can handle
await results.consume() # Ensure results are consumed
# Process all neighbors - capture all edges but only queue unvisited nodes
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=[target_id],
properties=dict(b_node._properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{current_node.id}",
target=f"{target_id}",
properties=dict(rel),
)
# Sort source_id and target_id so (A, B) and (B, A) are treated as the same edge
sorted_pair = tuple(sorted([current_node.id, target_id]))
# Check whether this edge already exists (treating edges as undirected)
if sorted_pair not in visited_edge_pairs:
# Only add the edge if the target node is already in the result or is about to be added
if target_id in visited_nodes or (
target_id not in visited_nodes
and current_depth < max_depth
):
result.edges.append(target_edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
# Only add unvisited nodes to the queue for further expansion
if target_id not in visited_nodes:
# Only add to queue if we're not at max depth yet
if current_depth < max_depth:
# Add node to queue with incremented depth
# Edge is already added to result, so we pass None as edge
queue.append((target_node, None, current_depth + 1))
else:
# At max depth, we've already added the edge but we don't add the node
# This prevents adding nodes beyond max_depth to the result
logger.debug(
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
)
else:
# If target node already exists in result, we don't need to add it again
logger.debug(
f"Node {target_id} already visited, edge added but node not queued"
)
else:
logger.warning(
f"Skipping edge {edge_id} due to missing entity_id on target node"
)
logger.info(
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def get_all_labels(self) -> list[str]:
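A side note on the visited_edge_pairs check introduced above: sorting the two endpoint ids before hashing is a common way to deduplicate undirected edges, since (A, B) and (B, A) then collide on the same key. A standalone sketch of the idea, with a hypothetical edge list:

def dedup_undirected(edges):
    # edges: iterable of (source_id, target_id) pairs.
    # tuple(sorted(...)) makes (A, B) and (B, A) hash identically,
    # so each undirected edge is kept exactly once.
    seen, unique = set(), []
    for src, dst in edges:
        pair = tuple(sorted((src, dst)))
        if pair not in seen:
            seen.add(pair)
            unique.append((src, dst))
    return unique

# e.g. dedup_undirected([("A", "B"), ("B", "A"), ("A", "C")])
# -> [("A", "B"), ("A", "C")]

This matters here because the Cypher pattern (a)-[r]-(b) is direction-agnostic, so each relationship is reachable from both endpoints during the BFS.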