Add is_truncate checking for PostgreSQL graph storage
This commit is contained in:
@@ -1479,20 +1479,37 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Args:
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
||||
max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
# First, count the total number of nodes that would be returned without limit
|
||||
if node_label == "*":
|
||||
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:base)
|
||||
RETURN count(distinct n) AS total_nodes
|
||||
$$) AS (total_nodes bigint)"""
|
||||
else:
|
||||
strip_label = node_label.strip('"')
|
||||
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:base {{entity_id: "{strip_label}"}})
|
||||
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
||||
RETURN count(distinct m) AS total_nodes
|
||||
$$) AS (total_nodes bigint)"""
|
||||
|
||||
# Build the query based on whether we want the full graph or a specific subgraph.
|
||||
count_result = await self._query(count_query)
|
||||
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
||||
is_truncated = total_nodes > max_nodes
|
||||
|
||||
# Now get the actual data with limit
|
||||
if node_label == "*":
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:base)
|
||||
OPTIONAL MATCH (n)-[r]->(target:base)
|
||||
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
||||
LIMIT {MAX_GRAPH_NODES}
|
||||
LIMIT {max_nodes}
|
||||
$$) AS (n agtype, r agtype)"""
|
||||
else:
|
||||
strip_label = node_label.strip('"')
|
||||
@@ -1559,7 +1576,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
kg = KnowledgeGraph(
|
||||
nodes=list(nodes_dict.values()),
|
||||
edges=list(edges_dict.values()),
|
||||
is_truncated=False,
|
||||
is_truncated=is_truncated,
|
||||
)
|
||||
|
||||
return kg
|
||||
|
Reference in New Issue
Block a user