Fix APOC fallback error
@@ -2,7 +2,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, final, Optional
+from typing import Any, final
 import numpy as np
 import configparser
 
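Note: Optional is dropped from the typing import because its only use in this
change, the edge: Optional[KnowledgeGraphEdge] parameter of the recursive
traverse helper, is removed below.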
@@ -713,6 +713,7 @@ class Neo4JStorage(BaseGraphStorage):
                     await result_set.consume()
 
                 else:
+                    # return await self._robust_fallback(node_label, max_depth, max_nodes)
                     # First try without limit to check if we need to truncate
                     full_query = """
                     MATCH (start)
@@ -855,98 +856,14 @@ class Neo4JStorage(BaseGraphStorage):
         """
         Fallback implementation when APOC plugin is not available or incompatible.
         This method implements the same functionality as get_knowledge_graph but uses
-        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        only basic Cypher queries and true breadth-first traversal instead of APOC procedures.
         """
+        from collections import deque
 
         result = KnowledgeGraph()
         visited_nodes = set()
        visited_edges = set()
+        visited_edge_pairs = set()  # Track processed edge pairs as sorted (source_id, target_id) tuples
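Note: visited_edge_pairs deduplicates undirected matches. The Cypher pattern
(a)-[r]-(b) is undirected, so the same relationship comes back once from each
endpoint, as (A, B) and again as (B, A); sorting the endpoint pair yields one
canonical key for both. A minimal standalone sketch of the idea (add_edge_once
is a hypothetical helper, not part of this commit):

    visited_edge_pairs = set()

    def add_edge_once(source_id: str, target_id: str) -> bool:
        # Canonical key: sorted endpoints, so (A, B) and (B, A) collide.
        pair = tuple(sorted((source_id, target_id)))
        if pair in visited_edge_pairs:
            return False  # already recorded from the other direction
        visited_edge_pairs.add(pair)
        return True

    assert add_edge_once("A", "B") is True
    assert add_edge_once("B", "A") is False  # same undirected edge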
-        async def traverse(
-            node: KnowledgeGraphNode,
-            edge: Optional[KnowledgeGraphEdge],
-            current_depth: int,
-        ):
-            # Check traversal limits
-            if current_depth > max_depth:
-                logger.debug(f"Reached max depth: {max_depth}")
-                return
-            if len(visited_nodes) >= max_nodes:
-                # Set truncated flag when we hit the max_nodes limit
-                result.is_truncated = True
-                logger.info(
-                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
-                )
-                return
-
-            # Check if node already visited
-            if node.id in visited_nodes:
-                return
-
-            # Get all edges and target nodes
-            async with self._driver.session(
-                database=self._DATABASE, default_access_mode="READ"
-            ) as session:
-                query = """
-                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
-                WITH r, b, id(r) as edge_id, id(b) as target_id
-                RETURN r, b, edge_id, target_id
-                """
-                results = await session.run(query, entity_id=node.id)
-
-                # Get all records and release database connection
-                records = await results.fetch(
-                    1000
-                )  # Max neighbour nodes we can handle
-                await results.consume()  # Ensure results are consumed
-
-                # Nodes not connected to start node need to check degree
-                # if current_depth > 1 and len(records) < min_degree:
-                #     return
-
-                # Add current node to result
-                result.nodes.append(node)
-                visited_nodes.add(node.id)
-
-                # Add edge to result if it exists and not already added
-                if edge and edge.id not in visited_edges:
-                    result.edges.append(edge)
-                    visited_edges.add(edge.id)
-
-                # Prepare nodes and edges for recursive processing
-                nodes_to_process = []
-                for record in records:
-                    rel = record["r"]
-                    edge_id = str(record["edge_id"])
-                    if edge_id not in visited_edges:
-                        b_node = record["b"]
-                        target_id = b_node.get("entity_id")
-
-                        if target_id:  # Only process if target node has entity_id
-                            # Create KnowledgeGraphNode for target
-                            target_node = KnowledgeGraphNode(
-                                id=f"{target_id}",
-                                labels=list(f"{target_id}"),
-                                properties=dict(b_node.properties),
-                            )
-
-                            # Create KnowledgeGraphEdge
-                            target_edge = KnowledgeGraphEdge(
-                                id=f"{edge_id}",
-                                type=rel.type,
-                                source=f"{node.id}",
-                                target=f"{target_id}",
-                                properties=dict(rel),
-                            )
-
-                            nodes_to_process.append((target_node, target_edge))
-                        else:
-                            logger.warning(
-                                f"Skipping edge {edge_id} due to missing labels on target node"
-                            )
-
-            # Process nodes after releasing database connection
-            for target_node, target_edge in nodes_to_process:
-                await traverse(target_node, target_edge, current_depth + 1)
 
         # Get the starting node's data
         async with self._driver.session(
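Note: the removed traverse helper expands each child fully before its siblings
(each await recurses to the bottom before the loop continues), so despite its
log message it walks the graph depth-first, and deep graphs can hit Python's
recursion limit. The rewrite below uses an explicit deque instead, which is
iterative and visits nodes strictly level by level. A minimal sketch of the
pattern, independent of Neo4j (neighbors is a hypothetical dict mapping a node
id to its adjacent ids):

    from collections import deque

    def bfs(start_id, neighbors, max_depth, max_nodes):
        visited, order = set(), []
        queue = deque([(start_id, 0)])  # (node_id, depth)
        while queue and len(visited) < max_nodes:
            node_id, depth = queue.popleft()
            if node_id in visited or depth > max_depth:
                continue
            visited.add(node_id)
            order.append(node_id)
            for nxt in neighbors.get(node_id, ()):
                if nxt not in visited and depth < max_depth:
                    queue.append((nxt, depth + 1))
        return order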
@@ -965,15 +882,129 @@ class Neo4JStorage(BaseGraphStorage):
                 # Create initial KnowledgeGraphNode
                 start_node = KnowledgeGraphNode(
                     id=f"{node_record['n'].get('entity_id')}",
-                    labels=list(f"{node_record['n'].get('entity_id')}"),
-                    properties=dict(node_record["n"].properties),
+                    labels=[node_record["n"].get("entity_id")],
+                    properties=dict(node_record["n"]._properties),
                 )
             finally:
                 await node_result.consume()  # Ensure results are consumed
 
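Note: the two replaced arguments fix separate bugs. list() over a string splits
it into characters, so a node whose entity_id is "Alice" previously got the
labels ['A', 'l', 'i', 'c', 'e'] rather than ['Alice']:

    entity_id = "Alice"
    print(list(f"{entity_id}"))  # ['A', 'l', 'i', 'c', 'e']  (old, broken)
    print([entity_id])           # ['Alice']                  (fixed)

And the neo4j Python driver's Node object keeps its key/value map in the
_properties attribute rather than a public .properties, so the old code raised
AttributeError; dict(node_record["n"]) should also work here, since Node
behaves as a mapping.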
-        # Start traversal with the initial node
-        await traverse(start_node, None, 0)
+        # Initialize queue for BFS with (node, edge, depth) tuples
+        # edge is None for the starting node
+        queue = deque([(start_node, None, 0)])
 
+        # True BFS implementation using a queue
+        while queue and len(visited_nodes) < max_nodes:
+            # Dequeue the next node to process
+            current_node, current_edge, current_depth = queue.popleft()
+
+            # Skip if already visited or exceeds max depth
+            if current_node.id in visited_nodes:
+                continue
+
+            if current_depth > max_depth:
+                logger.debug(
+                    f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
+                )
+                continue
+
+            # Add current node to result
+            result.nodes.append(current_node)
+            visited_nodes.add(current_node.id)
+
+            # Add edge to result if it exists and not already added
+            if current_edge and current_edge.id not in visited_edges:
+                result.edges.append(current_edge)
+                visited_edges.add(current_edge.id)
+
+            # Stop if we've reached the node limit
+            if len(visited_nodes) >= max_nodes:
+                result.is_truncated = True
+                logger.info(
+                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
+                )
+                break
+
+            # Get all edges and target nodes for the current node (even at max_depth)
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                query = """
+                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
+                WITH r, b, id(r) as edge_id, id(b) as target_id
+                RETURN r, b, edge_id, target_id
+                """
+                results = await session.run(query, entity_id=current_node.id)
+
+                # Get all records and release database connection
+                records = await results.fetch(1000)  # Max neighbor nodes we can handle
+                await results.consume()  # Ensure results are consumed
+
+                # Process all neighbors - capture all edges but only queue unvisited nodes
+                for record in records:
+                    rel = record["r"]
+                    edge_id = str(record["edge_id"])
+
+                    if edge_id not in visited_edges:
+                        b_node = record["b"]
+                        target_id = b_node.get("entity_id")
+
+                        if target_id:  # Only process if target node has entity_id
+                            # Create KnowledgeGraphNode for target
+                            target_node = KnowledgeGraphNode(
+                                id=f"{target_id}",
+                                labels=[target_id],
+                                properties=dict(b_node._properties),
+                            )
+
+                            # Create KnowledgeGraphEdge
+                            target_edge = KnowledgeGraphEdge(
+                                id=f"{edge_id}",
+                                type=rel.type,
+                                source=f"{current_node.id}",
+                                target=f"{target_id}",
+                                properties=dict(rel),
+                            )
+
+                            # Sort source_id and target_id so (A, B) and (B, A) count as the same edge
+                            sorted_pair = tuple(sorted([current_node.id, target_id]))
+
+                            # Check whether this edge already exists (treating edges as undirected)
+                            if sorted_pair not in visited_edge_pairs:
+                                # Only add the edge if the target node is already in the
+                                # result or is about to be added to it
+                                if target_id in visited_nodes or (
+                                    target_id not in visited_nodes
+                                    and current_depth < max_depth
+                                ):
+                                    result.edges.append(target_edge)
+                                    visited_edges.add(edge_id)
+                                    visited_edge_pairs.add(sorted_pair)
+
+                            # Only add unvisited nodes to the queue for further expansion
+                            if target_id not in visited_nodes:
+                                # Only add to queue if we're not at max depth yet
+                                if current_depth < max_depth:
+                                    # Add node to queue with incremented depth
+                                    # Edge is already added to result, so we pass None as edge
+                                    queue.append((target_node, None, current_depth + 1))
+                                else:
+                                    # At max depth we don't queue the node, which prevents
+                                    # adding nodes beyond max_depth to the result
+                                    logger.debug(
+                                        f"Node {target_id} beyond max depth {max_depth}, node not included"
+                                    )
+                            else:
+                                # If target node already exists in result, we don't need to add it again
+                                logger.debug(
+                                    f"Node {target_id} already visited, edge added but node not queued"
+                                )
+                        else:
+                            logger.warning(
+                                f"Skipping edge {edge_id} due to missing entity_id on target node"
+                            )
+
+        logger.info(
+            f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+        )
         return result
 
     async def get_all_labels(self) -> list[str]: