Add is_truncated checking for PostgreSQL graph storage

This commit is contained in:
yangdx
2025-04-03 16:30:06 +08:00
parent 0826b0b80c
commit 9b71295309

View File

@@ -1479,20 +1479,37 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
max_nodes: Maximum nodes to return, Defaults to 1000 (neither BFS nor DFS guaranteed)
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# First, count the total number of nodes that would be returned without limit
if node_label == "*":
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base)
RETURN count(distinct n) AS total_nodes
$$) AS (total_nodes bigint)"""
else:
strip_label = node_label.strip('"')
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: "{strip_label}"}})
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
RETURN count(distinct m) AS total_nodes
$$) AS (total_nodes bigint)"""
# Build the query based on whether we want the full graph or a specific subgraph.
count_result = await self._query(count_query)
total_nodes = count_result[0]["total_nodes"] if count_result else 0
is_truncated = total_nodes > max_nodes
# Now get the actual data with limit
if node_label == "*":
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base)
OPTIONAL MATCH (n)-[r]->(target:base)
RETURN collect(distinct n) AS n, collect(distinct r) AS r
LIMIT {MAX_GRAPH_NODES}
LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
else:
strip_label = node_label.strip('"')
@@ -1559,7 +1576,7 @@ class PGGraphStorage(BaseGraphStorage):
kg = KnowledgeGraph(
nodes=list(nodes_dict.values()),
edges=list(edges_dict.values()),
is_truncated=False,
is_truncated=is_truncated,
)
return kg