Add is_truncated to graph query for Neo4j
This commit is contained in:
@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraph object containing nodes and edges
|
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||||
|
indicating whether the graph was truncated due to max_nodes limit
|
||||||
"""
|
"""
|
||||||
result = KnowledgeGraph()
|
result = KnowledgeGraph()
|
||||||
seen_nodes = set()
|
seen_nodes = set()
|
||||||
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
) as session:
|
) as session:
|
||||||
try:
|
try:
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
|
# First check total node count to determine if graph is truncated
|
||||||
|
count_query = "MATCH (n) RETURN count(n) as total"
|
||||||
|
count_result = None
|
||||||
|
try:
|
||||||
|
count_result = await session.run(count_query)
|
||||||
|
count_record = await count_result.single()
|
||||||
|
|
||||||
|
if count_record and count_record["total"] > max_nodes:
|
||||||
|
result.is_truncated = True
|
||||||
|
logger.info(
|
||||||
|
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if count_result:
|
||||||
|
await count_result.consume()
|
||||||
|
|
||||||
|
# Run main query to get nodes with highest degree
|
||||||
main_query = """
|
main_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
OPTIONAL MATCH (n)-[r]-()
|
OPTIONAL MATCH (n)-[r]-()
|
||||||
@@ -683,14 +701,71 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
RETURN filtered_nodes AS node_info,
|
RETURN filtered_nodes AS node_info,
|
||||||
collect(DISTINCT r) AS relationships
|
collect(DISTINCT r) AS relationships
|
||||||
"""
|
"""
|
||||||
|
result_set = None
|
||||||
|
try:
|
||||||
result_set = await session.run(
|
result_set = await session.run(
|
||||||
main_query,
|
main_query,
|
||||||
{"max_nodes": max_nodes},
|
{"max_nodes": max_nodes},
|
||||||
)
|
)
|
||||||
|
record = await result_set.single()
|
||||||
|
finally:
|
||||||
|
if result_set:
|
||||||
|
await result_set.consume()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Main query uses partial matching
|
# First try without limit to check if we need to truncate
|
||||||
main_query = """
|
full_query = """
|
||||||
|
MATCH (start)
|
||||||
|
WHERE start.entity_id = $entity_id
|
||||||
|
WITH start
|
||||||
|
CALL apoc.path.subgraphAll(start, {
|
||||||
|
relationshipFilter: '',
|
||||||
|
minLevel: 0,
|
||||||
|
maxLevel: $max_depth,
|
||||||
|
bfs: true
|
||||||
|
})
|
||||||
|
YIELD nodes, relationships
|
||||||
|
WITH nodes, relationships, size(nodes) AS total_nodes
|
||||||
|
UNWIND nodes AS node
|
||||||
|
WITH collect({node: node}) AS node_info, relationships, total_nodes
|
||||||
|
RETURN node_info, relationships, total_nodes
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Try to get full result
|
||||||
|
full_result = None
|
||||||
|
try:
|
||||||
|
full_result = await session.run(
|
||||||
|
full_query,
|
||||||
|
{
|
||||||
|
"entity_id": node_label,
|
||||||
|
"max_depth": max_depth,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
full_record = await full_result.single()
|
||||||
|
|
||||||
|
# If no record found, return empty KnowledgeGraph
|
||||||
|
if not full_record:
|
||||||
|
logger.debug(f"No nodes found for entity_id: {node_label}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# If record found, check node count
|
||||||
|
total_nodes = full_record["total_nodes"]
|
||||||
|
|
||||||
|
if total_nodes <= max_nodes:
|
||||||
|
# If node count is within limit, use full result directly
|
||||||
|
logger.debug(
|
||||||
|
f"Using full result with {total_nodes} nodes (no truncation needed)"
|
||||||
|
)
|
||||||
|
record = full_record
|
||||||
|
else:
|
||||||
|
# If node count exceeds limit, set truncated flag and run limited query
|
||||||
|
result.is_truncated = True
|
||||||
|
logger.info(
|
||||||
|
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run limited query
|
||||||
|
limited_query = """
|
||||||
MATCH (start)
|
MATCH (start)
|
||||||
WHERE start.entity_id = $entity_id
|
WHERE start.entity_id = $entity_id
|
||||||
WITH start
|
WITH start
|
||||||
@@ -706,17 +781,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
WITH collect({node: node}) AS node_info, relationships
|
WITH collect({node: node}) AS node_info, relationships
|
||||||
RETURN node_info, relationships
|
RETURN node_info, relationships
|
||||||
"""
|
"""
|
||||||
|
result_set = None
|
||||||
|
try:
|
||||||
result_set = await session.run(
|
result_set = await session.run(
|
||||||
main_query,
|
limited_query,
|
||||||
{
|
{
|
||||||
"entity_id": node_label,
|
"entity_id": node_label,
|
||||||
"max_depth": max_depth,
|
"max_depth": max_depth,
|
||||||
"max_nodes": max_nodes,
|
"max_nodes": max_nodes,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
record = await result_set.single()
|
record = await result_set.single()
|
||||||
|
finally:
|
||||||
|
if result_set:
|
||||||
|
await result_set.consume()
|
||||||
|
finally:
|
||||||
|
if full_result:
|
||||||
|
await full_result.consume()
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
# Handle nodes (compatible with multi-label cases)
|
# Handle nodes (compatible with multi-label cases)
|
||||||
@@ -753,8 +834,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
|
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
await result_set.consume() # Ensure result set is consumed
|
|
||||||
|
|
||||||
except neo4jExceptions.ClientError as e:
|
except neo4jExceptions.ClientError as e:
|
||||||
logger.warning(f"APOC plugin error: {str(e)}")
|
logger.warning(f"APOC plugin error: {str(e)}")
|
||||||
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
"Neo4j: falling back to basic Cypher recursive search..."
|
"Neo4j: falling back to basic Cypher recursive search..."
|
||||||
)
|
)
|
||||||
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
return await self._robust_fallback(node_label, max_depth, max_nodes)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Neo4j: APOC plugin error with wildcard query, returning empty result"
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
|
|||||||
logger.debug(f"Reached max depth: {max_depth}")
|
logger.debug(f"Reached max depth: {max_depth}")
|
||||||
return
|
return
|
||||||
if len(visited_nodes) >= max_nodes:
|
if len(visited_nodes) >= max_nodes:
|
||||||
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
|
# Set truncated flag when we hit the max_nodes limit
|
||||||
|
result.is_truncated = True
|
||||||
|
logger.info(
|
||||||
|
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if node already visited
|
# Check if node already visited
|
||||||
|
Reference in New Issue
Block a user