Fix APOC fallback error
@@ -2,7 +2,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, final, Optional
+from typing import Any, final
 import numpy as np
 import configparser

@@ -713,6 +713,7 @@ class Neo4JStorage(BaseGraphStorage):
                     await result_set.consume()

                 else:
+                    # return await self._robust_fallback(node_label, max_depth, max_nodes)
                     # First try without limit to check if we need to truncate
                     full_query = """
                         MATCH (start)
@@ -855,98 +856,14 @@ class Neo4JStorage(BaseGraphStorage):
         """
         Fallback implementation when APOC plugin is not available or incompatible.
         This method implements the same functionality as get_knowledge_graph but uses
-        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
         """
         from collections import deque

         result = KnowledgeGraph()
         visited_nodes = set()
         visited_edges = set()
-
-        async def traverse(
-            node: KnowledgeGraphNode,
-            edge: Optional[KnowledgeGraphEdge],
-            current_depth: int,
-        ):
-            # Check traversal limits
-            if current_depth > max_depth:
-                logger.debug(f"Reached max depth: {max_depth}")
-                return
-            if len(visited_nodes) >= max_nodes:
-                # Set truncated flag when we hit the max_nodes limit
-                result.is_truncated = True
-                logger.info(
-                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
-                )
-                return
-
-            # Check if node already visited
-            if node.id in visited_nodes:
-                return
-
-            # Get all edges and target nodes
-            async with self._driver.session(
-                database=self._DATABASE, default_access_mode="READ"
-            ) as session:
-                query = """
-                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
-                WITH r, b, id(r) as edge_id, id(b) as target_id
-                RETURN r, b, edge_id, target_id
-                """
-                results = await session.run(query, entity_id=node.id)
-
-                # Get all records and release database connection
-                records = await results.fetch(
-                    1000
-                )  # Max neighbour nodes we can handle
-                await results.consume()  # Ensure results are consumed
-
-                # Nodes not connected to start node need to check degree
-                # if current_depth > 1 and len(records) < min_degree:
-                #     return
-
-                # Add current node to result
-                result.nodes.append(node)
-                visited_nodes.add(node.id)
-
-                # Add edge to result if it exists and not already added
-                if edge and edge.id not in visited_edges:
-                    result.edges.append(edge)
-                    visited_edges.add(edge.id)
-
-                # Prepare nodes and edges for recursive processing
-                nodes_to_process = []
-                for record in records:
-                    rel = record["r"]
-                    edge_id = str(record["edge_id"])
-                    if edge_id not in visited_edges:
-                        b_node = record["b"]
-                        target_id = b_node.get("entity_id")
-
-                        if target_id:  # Only process if target node has entity_id
-                            # Create KnowledgeGraphNode for target
-                            target_node = KnowledgeGraphNode(
-                                id=f"{target_id}",
-                                labels=list(f"{target_id}"),
-                                properties=dict(b_node.properties),
-                            )
-
-                            # Create KnowledgeGraphEdge
-                            target_edge = KnowledgeGraphEdge(
-                                id=f"{edge_id}",
-                                type=rel.type,
-                                source=f"{node.id}",
-                                target=f"{target_id}",
-                                properties=dict(rel),
-                            )
-
-                            nodes_to_process.append((target_node, target_edge))
-                        else:
-                            logger.warning(
-                                f"Skipping edge {edge_id} due to missing labels on target node"
-                            )
-
-            # Process nodes after releasing database connection
-            for target_node, target_edge in nodes_to_process:
-                await traverse(target_node, target_edge, current_depth + 1)
+        visited_edge_pairs = set()  # Track processed edge pairs (sorted source_id, target_id)

         # Get the starting node's data
         async with self._driver.session(
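The removed `traverse` above recursed once per neighbour, which explores depth-first in practice and can hit Python's recursion limit on deep graphs; the next hunk replaces it with an explicit `collections.deque`. A minimal, runnable sketch of that queue pattern follows; `neighbors` is a hypothetical stand-in for the Cypher neighbour query, not part of the commit:

from collections import deque

def bfs_nodes(start, neighbors, max_depth, max_nodes):
    # Bounded breadth-first walk: nodes come out in depth order, never deeper
    # than max_depth and never more than max_nodes in total.
    visited, order = set(), []
    queue = deque([(start, 0)])  # (node, depth) pairs
    while queue and len(visited) < max_nodes:
        node, depth = queue.popleft()
        if node in visited or depth > max_depth:
            continue
        visited.add(node)
        order.append(node)
        if depth < max_depth:
            for nxt in neighbors(node):
                if nxt not in visited:
                    queue.append((nxt, depth + 1))
    return order

# Toy usage on an adjacency map:
graph = {"A": ["B", "C"], "B": ["D"], "C": ["D"], "D": []}
print(bfs_nodes("A", graph.get, max_depth=2, max_nodes=10))  # ['A', 'B', 'C', 'D']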
@@ -965,15 +882,129 @@ class Neo4JStorage(BaseGraphStorage):
                 # Create initial KnowledgeGraphNode
                 start_node = KnowledgeGraphNode(
                     id=f"{node_record['n'].get('entity_id')}",
-                    labels=list(f"{node_record['n'].get('entity_id')}"),
-                    properties=dict(node_record["n"].properties),
+                    labels=[node_record["n"].get("entity_id")],
+                    properties=dict(node_record["n"]._properties),
                 )
             finally:
                 await node_result.consume()  # Ensure results are consumed

-        # Start traversal with the initial node
-        await traverse(start_node, None, 0)
+        # Initialize queue for BFS with (node, edge, depth) tuples
+        # edge is None for the starting node
+        queue = deque([(start_node, None, 0)])
+
+        # True BFS implementation using a queue
+        while queue and len(visited_nodes) < max_nodes:
+            # Dequeue the next node to process
+            current_node, current_edge, current_depth = queue.popleft()
+
+            # Skip if already visited or exceeds max depth
+            if current_node.id in visited_nodes:
+                continue
+
+            if current_depth > max_depth:
+                logger.debug(
+                    f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
+                )
+                continue
+
+            # Add current node to result
+            result.nodes.append(current_node)
+            visited_nodes.add(current_node.id)
+
+            # Add edge to result if it exists and not already added
+            if current_edge and current_edge.id not in visited_edges:
+                result.edges.append(current_edge)
+                visited_edges.add(current_edge.id)
+
+            # Stop if we've reached the node limit
+            if len(visited_nodes) >= max_nodes:
+                result.is_truncated = True
+                logger.info(
+                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
+                )
+                break
+
+            # Get all edges and target nodes for the current node (even at max_depth)
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                query = """
+                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
+                WITH r, b, id(r) as edge_id, id(b) as target_id
+                RETURN r, b, edge_id, target_id
+                """
+                results = await session.run(query, entity_id=current_node.id)
+
+                # Get all records and release database connection
+                records = await results.fetch(1000)  # Max neighbor nodes we can handle
+                await results.consume()  # Ensure results are consumed
+
+                # Process all neighbors - capture all edges but only queue unvisited nodes
+                for record in records:
+                    rel = record["r"]
+                    edge_id = str(record["edge_id"])
+
+                    if edge_id not in visited_edges:
+                        b_node = record["b"]
+                        target_id = b_node.get("entity_id")
+
+                        if target_id:  # Only process if target node has entity_id
+                            # Create KnowledgeGraphNode for target
+                            target_node = KnowledgeGraphNode(
+                                id=f"{target_id}",
+                                labels=[target_id],
+                                properties=dict(b_node._properties),
+                            )
+
+                            # Create KnowledgeGraphEdge
+                            target_edge = KnowledgeGraphEdge(
+                                id=f"{edge_id}",
+                                type=rel.type,
+                                source=f"{current_node.id}",
+                                target=f"{target_id}",
+                                properties=dict(rel),
+                            )
+
+                            # Sort source_id and target_id so that (A, B) and (B, A) count as the same edge
+                            sorted_pair = tuple(sorted([current_node.id, target_id]))
+
+                            # Check whether this edge already exists (treating edges as undirected)
+                            if sorted_pair not in visited_edge_pairs:
+                                # Only add the edge if the target node is already in the result or will be added to it
+                                if target_id in visited_nodes or (
+                                    target_id not in visited_nodes
+                                    and current_depth < max_depth
+                                ):
+                                    result.edges.append(target_edge)
+                                    visited_edges.add(edge_id)
+                                    visited_edge_pairs.add(sorted_pair)
+
+                            # Only add unvisited nodes to the queue for further expansion
+                            if target_id not in visited_nodes:
+                                # Only add to queue if we're not at max depth yet
+                                if current_depth < max_depth:
+                                    # Add node to queue with incremented depth
+                                    # Edge is already added to result, so we pass None as edge
+                                    queue.append((target_node, None, current_depth + 1))
+                                else:
+                                    # At max depth, we've already added the edge but we don't add the node
+                                    # This prevents adding nodes beyond max_depth to the result
+                                    logger.debug(
+                                        f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
+                                    )
+                            else:
+                                # If the target node is already in the result, we don't need to add it again
+                                logger.debug(
+                                    f"Node {target_id} already visited, edge added but node not queued"
+                                )
+                        else:
+                            logger.warning(
+                                f"Skipping edge {edge_id} due to missing entity_id on target node"
+                            )
+
+        logger.info(
+            f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+        )
+        return result

     async def get_all_labels(self) -> list[str]:
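Two of the smaller correctness fixes in the last hunk are easy to miss. First, the old `labels=list(f"{entity_id}")` split the id into one label per character, because `list()` over a string iterates its characters; the fix wraps the id in a one-element list. A quick demonstration:

entity_id = "Alice"
print(list(f"{entity_id}"))  # ['A', 'l', 'i', 'c', 'e']  (old: one label per character)
print([entity_id])           # ['Alice']  (fixed: one label holding the whole id)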
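Second, because the fallback matches `(a)-[r]-(b)` with no direction, each relationship can come back once from each endpoint. The new `visited_edge_pairs` set keys edges by their sorted endpoint ids so both orientations collapse to one entry. A standalone sketch of that dedup:

seen_pairs = set()
edges = [("A", "B"), ("B", "A"), ("A", "C")]  # ("B", "A") duplicates ("A", "B")
unique = []
for src, dst in edges:
    pair = tuple(sorted((src, dst)))  # both orientations sort to ("A", "B")
    if pair not in seen_pairs:
        seen_pairs.add(pair)
        unique.append((src, dst))
print(unique)  # [('A', 'B'), ('A', 'C')]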