Add is_truncated to graph query for Neo4j

This commit is contained in:
yangdx
2025-04-02 23:20:07 +08:00
parent 72132ee1d6
commit c339f8686a

View File

@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
result = KnowledgeGraph()
seen_nodes = set()
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
) as session:
try:
if node_label == "*":
# First check total node count to determine if graph is truncated
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
try:
count_result = await session.run(count_query)
count_record = await count_result.single()
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
await count_result.consume()
# Run main query to get nodes with highest degree
main_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
@@ -683,14 +701,20 @@ class Neo4JStorage(BaseGraphStorage):
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = await session.run(
main_query,
{"max_nodes": max_nodes},
)
result_set = None
try:
result_set = await session.run(
main_query,
{"max_nodes": max_nodes},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
else:
# Main query uses partial matching
main_query = """
# First try without limit to check if we need to truncate
full_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
@@ -698,63 +722,118 @@ class Neo4JStorage(BaseGraphStorage):
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
limit: $max_nodes,
bfs: true
})
YIELD nodes, relationships
WITH nodes, relationships, size(nodes) AS total_nodes
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships
WITH collect({node: node}) AS node_info, relationships, total_nodes
RETURN node_info, relationships, total_nodes
"""
result_set = await session.run(
main_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
try:
record = await result_set.single()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
# Try to get full result
full_result = None
try:
full_result = await session.run(
full_query,
{
"entity_id": node_label,
"max_depth": max_depth,
},
)
finally:
await result_set.consume() # Ensure result set is consumed
full_record = await full_result.single()
# If no record found, return empty KnowledgeGraph
if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}")
return result
# If record found, check node count
total_nodes = full_record["total_nodes"]
if total_nodes <= max_nodes:
# If node count is within limit, use full result directly
logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)"
)
record = full_record
else:
# If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True
logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
)
# Run limited query
limited_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
limit: $max_nodes,
bfs: true
})
YIELD nodes, relationships
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships
"""
result_set = None
try:
result_set = await session.run(
limited_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
finally:
if full_result:
await full_result.consume()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
)
except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}")
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
"Neo4j: falling back to basic Cypher recursive search..."
)
return await self._robust_fallback(node_label, max_depth, max_nodes)
else:
logger.warning(
"Neo4j: APOC plugin error with wildcard query, returning empty result"
)
return result
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug(f"Reached max depth: {max_depth}")
return
if len(visited_nodes) >= max_nodes:
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
# Set truncated flag when we hit the max_nodes limit
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
return
# Check if node already visited