Add is_truncate checking for PostgreSQL graph storage
This commit is contained in:
@@ -1479,20 +1479,37 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Args:
|
Args:
|
||||||
node_label: Label of the starting node, * means all nodes
|
node_label: Label of the starting node, * means all nodes
|
||||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||||
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
|
max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||||
indicating whether the graph was truncated due to max_nodes limit
|
indicating whether the graph was truncated due to max_nodes limit
|
||||||
"""
|
"""
|
||||||
|
# First, count the total number of nodes that would be returned without limit
|
||||||
|
if node_label == "*":
|
||||||
|
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
|
MATCH (n:base)
|
||||||
|
RETURN count(distinct n) AS total_nodes
|
||||||
|
$$) AS (total_nodes bigint)"""
|
||||||
|
else:
|
||||||
|
strip_label = node_label.strip('"')
|
||||||
|
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
|
MATCH (n:base {{entity_id: "{strip_label}"}})
|
||||||
|
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
||||||
|
RETURN count(distinct m) AS total_nodes
|
||||||
|
$$) AS (total_nodes bigint)"""
|
||||||
|
|
||||||
# Build the query based on whether we want the full graph or a specific subgraph.
|
count_result = await self._query(count_query)
|
||||||
|
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
||||||
|
is_truncated = total_nodes > max_nodes
|
||||||
|
|
||||||
|
# Now get the actual data with limit
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
MATCH (n:base)
|
MATCH (n:base)
|
||||||
OPTIONAL MATCH (n)-[r]->(target:base)
|
OPTIONAL MATCH (n)-[r]->(target:base)
|
||||||
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
RETURN collect(distinct n) AS n, collect(distinct r) AS r
|
||||||
LIMIT {MAX_GRAPH_NODES}
|
LIMIT {max_nodes}
|
||||||
$$) AS (n agtype, r agtype)"""
|
$$) AS (n agtype, r agtype)"""
|
||||||
else:
|
else:
|
||||||
strip_label = node_label.strip('"')
|
strip_label = node_label.strip('"')
|
||||||
@@ -1559,7 +1576,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
kg = KnowledgeGraph(
|
kg = KnowledgeGraph(
|
||||||
nodes=list(nodes_dict.values()),
|
nodes=list(nodes_dict.values()),
|
||||||
edges=list(edges_dict.values()),
|
edges=list(edges_dict.values()),
|
||||||
is_truncated=False,
|
is_truncated=is_truncated,
|
||||||
)
|
)
|
||||||
|
|
||||||
return kg
|
return kg
|
||||||
|
Reference in New Issue
Block a user