Fix APOC fallback error

This commit is contained in:
yangdx
2025-04-04 02:05:29 +08:00
parent 399b2f14f6
commit b003537429

View File

@@ -2,7 +2,7 @@ import inspect
import os import os
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, final, Optional from typing import Any, final
import numpy as np import numpy as np
import configparser import configparser
@@ -713,6 +713,7 @@ class Neo4JStorage(BaseGraphStorage):
await result_set.consume() await result_set.consume()
else: else:
# return await self._robust_fallback(node_label, max_depth, max_nodes)
# First try without limit to check if we need to truncate # First try without limit to check if we need to truncate
full_query = """ full_query = """
MATCH (start) MATCH (start)
@@ -855,98 +856,14 @@ class Neo4JStorage(BaseGraphStorage):
""" """
Fallback implementation when APOC plugin is not available or incompatible. Fallback implementation when APOC plugin is not available or incompatible.
This method implements the same functionality as get_knowledge_graph but uses This method implements the same functionality as get_knowledge_graph but uses
only basic Cypher queries and recursive traversal instead of APOC procedures. only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
""" """
from collections import deque
result = KnowledgeGraph() result = KnowledgeGraph()
visited_nodes = set() visited_nodes = set()
visited_edges = set() visited_edges = set()
visited_edge_pairs = set() # 用于跟踪已处理的边对(排序后的source_id, target_id)
async def traverse(
node: KnowledgeGraphNode,
edge: Optional[KnowledgeGraphEdge],
current_depth: int,
):
# Check traversal limits
if current_depth > max_depth:
logger.debug(f"Reached max depth: {max_depth}")
return
if len(visited_nodes) >= max_nodes:
# Set truncated flag when we hit the max_nodes limit
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
return
# Check if node already visited
if node.id in visited_nodes:
return
# Get all edges and target nodes
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=node.id)
# Get all records and release database connection
records = await results.fetch(
1000
) # Max neighbour nodes we can handled
await results.consume() # Ensure results are consumed
# Nodes not connected to start node need to check degree
# if current_depth > 1 and len(records) < min_degree:
# return
# Add current node to result
result.nodes.append(node)
visited_nodes.add(node.id)
# Add edge to result if it exists and not already added
if edge and edge.id not in visited_edges:
result.edges.append(edge)
visited_edges.add(edge.id)
# Prepare nodes and edges for recursive processing
nodes_to_process = []
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=list(f"{target_id}"),
properties=dict(b_node.properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{node.id}",
target=f"{target_id}",
properties=dict(rel),
)
nodes_to_process.append((target_node, target_edge))
else:
logger.warning(
f"Skipping edge {edge_id} due to missing labels on target node"
)
# Process nodes after releasing database connection
for target_node, target_edge in nodes_to_process:
await traverse(target_node, target_edge, current_depth + 1)
# Get the starting node's data # Get the starting node's data
async with self._driver.session( async with self._driver.session(
@@ -965,15 +882,129 @@ class Neo4JStorage(BaseGraphStorage):
# Create initial KnowledgeGraphNode # Create initial KnowledgeGraphNode
start_node = KnowledgeGraphNode( start_node = KnowledgeGraphNode(
id=f"{node_record['n'].get('entity_id')}", id=f"{node_record['n'].get('entity_id')}",
labels=list(f"{node_record['n'].get('entity_id')}"), labels=[node_record["n"].get("entity_id")],
properties=dict(node_record["n"].properties), properties=dict(node_record["n"]._properties),
) )
finally: finally:
await node_result.consume() # Ensure results are consumed await node_result.consume() # Ensure results are consumed
# Start traversal with the initial node # Initialize queue for BFS with (node, edge, depth) tuples
await traverse(start_node, None, 0) # edge is None for the starting node
queue = deque([(start_node, None, 0)])
# True BFS implementation using a queue
while queue and len(visited_nodes) < max_nodes:
# Dequeue the next node to process
current_node, current_edge, current_depth = queue.popleft()
# Skip if already visited or exceeds max depth
if current_node.id in visited_nodes:
continue
if current_depth > max_depth:
logger.debug(
f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
)
continue
# Add current node to result
result.nodes.append(current_node)
visited_nodes.add(current_node.id)
# Add edge to result if it exists and not already added
if current_edge and current_edge.id not in visited_edges:
result.edges.append(current_edge)
visited_edges.add(current_edge.id)
# Stop if we've reached the node limit
if len(visited_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
break
# Get all edges and target nodes for the current node (even at max_depth)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
results = await session.run(query, entity_id=current_node.id)
# Get all records and release database connection
records = await results.fetch(1000) # Max neighbor nodes we can handle
await results.consume() # Ensure results are consumed
# Process all neighbors - capture all edges but only queue unvisited nodes
for record in records:
rel = record["r"]
edge_id = str(record["edge_id"])
if edge_id not in visited_edges:
b_node = record["b"]
target_id = b_node.get("entity_id")
if target_id: # Only process if target node has entity_id
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=f"{target_id}",
labels=[target_id],
properties=dict(b_node._properties),
)
# Create KnowledgeGraphEdge
target_edge = KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{current_node.id}",
target=f"{target_id}",
properties=dict(rel),
)
# 对source_id和target_id进行排序确保(A,B)和(B,A)被视为同一条边
sorted_pair = tuple(sorted([current_node.id, target_id]))
# 检查是否已存在相同的边(考虑无向性)
if sorted_pair not in visited_edge_pairs:
# 只有当目标节点已经在结果中或将被添加到结果中时,才添加边
if target_id in visited_nodes or (
target_id not in visited_nodes
and current_depth < max_depth
):
result.edges.append(target_edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
# Only add unvisited nodes to the queue for further expansion
if target_id not in visited_nodes:
# Only add to queue if we're not at max depth yet
if current_depth < max_depth:
# Add node to queue with incremented depth
# Edge is already added to result, so we pass None as edge
queue.append((target_node, None, current_depth + 1))
else:
# At max depth, we've already added the edge but we don't add the node
# This prevents adding nodes beyond max_depth to the result
logger.debug(
f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
)
else:
# If target node already exists in result, we don't need to add it again
logger.debug(
f"Node {target_id} already visited, edge added but node not queued"
)
else:
logger.warning(
f"Skipping edge {edge_id} due to missing entity_id on target node"
)
logger.info(
f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result return result
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]: