Add is_truncated to graph query for Neo4j

This commit is contained in:
yangdx
2025-04-02 23:20:07 +08:00
parent 72132ee1d6
commit c339f8686a

View File

@@ -658,7 +658,8 @@ class Neo4JStorage(BaseGraphStorage):
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
""" """
result = KnowledgeGraph() result = KnowledgeGraph()
seen_nodes = set() seen_nodes = set()
@@ -669,6 +670,23 @@ class Neo4JStorage(BaseGraphStorage):
) as session: ) as session:
try: try:
if node_label == "*": if node_label == "*":
# First check total node count to determine if graph is truncated
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
try:
count_result = await session.run(count_query)
count_record = await count_result.single()
if count_record and count_record["total"] > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
await count_result.consume()
# Run main query to get nodes with highest degree
main_query = """ main_query = """
MATCH (n) MATCH (n)
OPTIONAL MATCH (n)-[r]-() OPTIONAL MATCH (n)-[r]-()
@@ -683,14 +701,20 @@ class Neo4JStorage(BaseGraphStorage):
RETURN filtered_nodes AS node_info, RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships collect(DISTINCT r) AS relationships
""" """
result_set = await session.run( result_set = None
main_query, try:
{"max_nodes": max_nodes}, result_set = await session.run(
) main_query,
{"max_nodes": max_nodes},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
else: else:
# Main query uses partial matching # First try without limit to check if we need to truncate
main_query = """ full_query = """
MATCH (start) MATCH (start)
WHERE start.entity_id = $entity_id WHERE start.entity_id = $entity_id
WITH start WITH start
@@ -698,63 +722,118 @@ class Neo4JStorage(BaseGraphStorage):
relationshipFilter: '', relationshipFilter: '',
minLevel: 0, minLevel: 0,
maxLevel: $max_depth, maxLevel: $max_depth,
limit: $max_nodes,
bfs: true bfs: true
}) })
YIELD nodes, relationships YIELD nodes, relationships
WITH nodes, relationships, size(nodes) AS total_nodes
UNWIND nodes AS node UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships WITH collect({node: node}) AS node_info, relationships, total_nodes
RETURN node_info, relationships RETURN node_info, relationships, total_nodes
""" """
result_set = await session.run(
main_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
try: # Try to get full result
record = await result_set.single() full_result = None
try:
if record: full_result = await session.run(
# Handle nodes (compatible with multi-label cases) full_query,
for node_info in record["node_info"]: {
node = node_info["node"] "entity_id": node_label,
node_id = node.id "max_depth": max_depth,
if node_id not in seen_nodes: },
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
) )
finally: full_record = await full_result.single()
await result_set.consume() # Ensure result set is consumed
# If no record found, return empty KnowledgeGraph
if not full_record:
logger.debug(f"No nodes found for entity_id: {node_label}")
return result
# If record found, check node count
total_nodes = full_record["total_nodes"]
if total_nodes <= max_nodes:
# If node count is within limit, use full result directly
logger.debug(
f"Using full result with {total_nodes} nodes (no truncation needed)"
)
record = full_record
else:
# If node count exceeds limit, set truncated flag and run limited query
result.is_truncated = True
logger.info(
f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
)
# Run limited query
limited_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
limit: $max_nodes,
bfs: true
})
YIELD nodes, relationships
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships
"""
result_set = None
try:
result_set = await session.run(
limited_query,
{
"entity_id": node_label,
"max_depth": max_depth,
"max_nodes": max_nodes,
},
)
record = await result_set.single()
finally:
if result_set:
await result_set.consume()
finally:
if full_result:
await full_result.consume()
if record:
# Handle nodes (compatible with multi-label cases)
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
seen_edges.add(edge_id)
logger.info(
f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
)
except neo4jExceptions.ClientError as e: except neo4jExceptions.ClientError as e:
logger.warning(f"APOC plugin error: {str(e)}") logger.warning(f"APOC plugin error: {str(e)}")
@@ -763,6 +842,10 @@ class Neo4JStorage(BaseGraphStorage):
"Neo4j: falling back to basic Cypher recursive search..." "Neo4j: falling back to basic Cypher recursive search..."
) )
return await self._robust_fallback(node_label, max_depth, max_nodes) return await self._robust_fallback(node_label, max_depth, max_nodes)
else:
logger.warning(
"Neo4j: APOC plugin error with wildcard query, returning empty result"
)
return result return result
@@ -788,7 +871,11 @@ class Neo4JStorage(BaseGraphStorage):
logger.debug(f"Reached max depth: {max_depth}") logger.debug(f"Reached max depth: {max_depth}")
return return
if len(visited_nodes) >= max_nodes: if len(visited_nodes) >= max_nodes:
logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}") # Set truncated flag when we hit the max_nodes limit
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
)
return return
# Check if node already visited # Check if node already visited