Add is_truncated to graph query for Neo4j
This commit is contained in:
@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
max_nodes: Maximum nodes to return by BFS, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
) as session:
|
||||
try:
|
||||
if node_label == "*":
|
||||
# First check total node count to determine if graph is truncated
|
||||
count_query = "MATCH (n) RETURN count(n) as total"
|
||||
count_result = None
|
||||
try:
|
||||
count_result = await session.run(count_query)
|
||||
count_record = await count_result.single()
|
||||
|
||||
if count_record and count_record["total"] > max_nodes:
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
||||
)
|
||||
finally:
|
||||
if count_result:
|
||||
await count_result.consume()
|
||||
|
||||
# Run main query to get nodes with highest degree
|
||||
main_query = """
|
||||
MATCH (n)
|
||||
OPTIONAL MATCH (n)-[r]-()
|
||||
@@ -683,14 +701,71 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
RETURN filtered_nodes AS node_info,
|
||||
collect(DISTINCT r) AS relationships
|
||||
"""
|
||||
result_set = None
|
||||
try:
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
{"max_nodes": max_nodes},
|
||||
)
|
||||
record = await result_set.single()
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
|
||||
else:
|
||||
# Main query uses partial matching
|
||||
main_query = """
|
||||
# First try without limit to check if we need to truncate
|
||||
full_query = """
|
||||
MATCH (start)
|
||||
WHERE start.entity_id = $entity_id
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {
|
||||
relationshipFilter: '',
|
||||
minLevel: 0,
|
||||
maxLevel: $max_depth,
|
||||
bfs: true
|
||||
})
|
||||
YIELD nodes, relationships
|
||||
WITH nodes, relationships, size(nodes) AS total_nodes
|
||||
UNWIND nodes AS node
|
||||
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
||||
RETURN node_info, relationships, total_nodes
|
||||
"""
|
||||
|
||||
# Try to get full result
|
||||
full_result = None
|
||||
try:
|
||||
full_result = await session.run(
|
||||
full_query,
|
||||
{
|
||||
"entity_id": node_label,
|
||||
"max_depth": max_depth,
|
||||
},
|
||||
)
|
||||
full_record = await full_result.single()
|
||||
|
||||
# If no record found, return empty KnowledgeGraph
|
||||
if not full_record:
|
||||
logger.debug(f"No nodes found for entity_id: {node_label}")
|
||||
return result
|
||||
|
||||
# If record found, check node count
|
||||
total_nodes = full_record["total_nodes"]
|
||||
|
||||
if total_nodes <= max_nodes:
|
||||
# If node count is within limit, use full result directly
|
||||
logger.debug(
|
||||
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
||||
)
|
||||
record = full_record
|
||||
else:
|
||||
# If node count exceeds limit, set truncated flag and run limited query
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
||||
)
|
||||
|
||||
# Run limited query
|
||||
limited_query = """
|
||||
MATCH (start)
|
||||
WHERE start.entity_id = $entity_id
|
||||
WITH start
|
||||
@@ -706,17 +781,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
WITH collect({node: node}) AS node_info, relationships
|
||||
RETURN node_info, relationships
|
||||
"""
|
||||
result_set = None
|
||||
try:
|
||||
result_set = await session.run(
|
||||
main_query,
|
||||
limited_query,
|
||||
{
|
||||
"entity_id": node_label,
|
||||
"max_depth": max_depth,
|
||||
"max_nodes": max_nodes,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
record = await result_set.single()
|
||||
finally:
|
||||
if result_set:
|
||||
await result_set.consume()
|
||||
finally:
|
||||
if full_result:
|
||||
await full_result.consume()
|
||||
|
||||
if record:
|
||||
# Handle nodes (compatible with multi-label cases)
|
||||
@@ -753,8 +834,6 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.info(
|
||||
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
|
||||
)
|
||||
finally:
|
||||
await result_set.consume() # Ensure result set is consumed
|
||||
|
||||
except neo4jExceptions.ClientError as e:
|
||||
logger.warning(f"APOC plugin error: {str(e)}")
|
||||
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
"Neo4j: falling back to basic Cypher recursive search..."
|
||||
)
|
||||
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
||||
else:
|
||||
logger.warning(
|
||||
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.debug(f"Reached max depth: {max_depth}")
|
||||
return
|
||||
if len(visited_nodes) >= max_nodes:
|
||||
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
|
||||
# Set truncated flag when we hit the max_nodes limit
|
||||
result.is_truncated = True
|
||||
logger.info(
|
||||
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if node already visited
|
||||
|
Reference in New Issue
Block a user