Add is_truncated to graph query for Neo4j

This commit is contained in:
yangdx
2025-04-02 23:20:07 +08:00
parent 72132ee1d6
commit c339f8686a

View File

@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
max_nodes: Maximum nodes to return by BFS, Defaults to 1000 max_nodes: Maximum nodes to return by BFS, Defaults to 1000
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
""" """
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
) as session: ) as session:
try: try:
if node_label == "*": if node_label == "*":
# First check total node count to determine if graph is truncated
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
try:
count_result = await session.run(count_query)
count_record = await count_result.single()
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
await count_result.consume()
# Run main query to get nodes with highest degree
main_query = """ main_query = """
MATCH (n) MATCH (n)
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
@@ -683,14 +701,71 @@ class Neo4JStorage(BaseGraphStorage):
RETURN filtered_nodes AS node_info, RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships collect(DISTINCT r) AS relationships
""" """
result_set = None
try:
result_set = await session.run( result_set = await session.run(
main_query, main_query,
{"max_nodes": max_nodes}, {"max_nodes": max_nodes},
) )
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
else: else:
# Main query uses partial matching # First try without limit to check if we need to truncate
main_query = """ full_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
bfs: true
})
YIELD nodes, relationships
WITH nodes, relationships, size(nodes) AS total_nodes
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships, total_nodes
RETURN node_info, relationships, total_nodes
"""
# Try to get full result
full_result = None
try:
full_result = await session.run(
full_query,
{
"entity_id": node_label,
"max_depth": max_depth,
},
)
full_record = await full_result.single()
# If no record found, return empty KnowledgeGraph
if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}")
return result
# If record found, check node count
total_nodes = full_record["total_nodes"]
if total_nodes <= max_nodes:
# If node count is within limit, use full result directly
logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)"
)
record = full_record
else:
# If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True
logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
)
# Run limited query
limited_query = """
MATCH (start) MATCH (start)
WHERE start.entity_id = $entity_id WHERE start.entity_id = $entity_id
WITH start WITH start
@@ -706,17 +781,23 @@ class Neo4JStorage(BaseGraphStorage):
WITH collect({node: node}) AS node_info, relationships WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships RETURN node_info, relationships
""" """
result_set = None
try:
result_set = await session.run( result_set = await session.run(
main_query, limited_query,
{ {
"entity_id": node_label, "entity_id": node_label,
"max_depth": max_depth, "max_depth": max_depth,
"max_nodes": max_nodes, "max_nodes": max_nodes,
}, },
) )
try:
record = await result_set.single() record = await result_set.single()
finally:
if result_set:
await result_set.consume()
finally:
if full_result:
await full_result.consume()
if record: if record:
# Handle nodes (compatible with multi-label cases) # Handle nodes (compatible with multi-label cases)
@@ -753,8 +834,6 @@ class Neo4JStorage(BaseGraphStorage):
logger.info( logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges" f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
) )
finally:
await result_set.consume() # Ensure result set is consumed
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}") logger.warning(f"APOC plugin error: {str(e)}")
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
"Neo4j: falling back to basic Cypher recursive search..." "Neo4j: falling back to basic Cypher recursive search..."
) )
return await self._robust_fallback(node_label, max_depth, max_nodes) return await self._robust_fallback(node_label, max_depth, max_nodes)
else:
logger.warning(
"Neo4j: APOC plugin error with wildcard query, returning empty result"
)
return result return result
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug(f"Reached max depth: {max_depth}") logger.debug(f"Reached max depth: {max_depth}")
return return
if len(visited_nodes) >= max_nodes: if len(visited_nodes) >= max_nodes:
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") # Set truncated flag when we hit the max_nodes limit
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
return return
# Check if node already visited # Check if node already visited