diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7b7503dd..28d1aa3b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1479,20 +1479,37 @@ class PGGraphStorage(BaseGraphStorage): Args: node_label: Label of the starting node, * means all nodes max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed) Returns: KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ + # First, count the total number of nodes that would be returned without limit + if node_label == "*": + count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + RETURN count(distinct n) AS total_nodes + $$) AS (total_nodes bigint)""" + else: + strip_label = node_label.strip('"') + count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base {{entity_id: "{strip_label}"}}) + OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) + RETURN count(distinct m) AS total_nodes + $$) AS (total_nodes bigint)""" - # Build the query based on whether we want the full graph or a specific subgraph. + count_result = await self._query(count_query) + total_nodes = count_result[0]["total_nodes"] if count_result else 0 + is_truncated = total_nodes > max_nodes + + # Now get the actual data with limit if node_label == "*": query = f"""SELECT * FROM cypher('{self.graph_name}', $$ MATCH (n:base) OPTIONAL MATCH (n)-[r]->(target:base) RETURN collect(distinct n) AS n, collect(distinct r) AS r - LIMIT {MAX_GRAPH_NODES} + LIMIT {max_nodes} $$) AS (n agtype, r agtype)""" else: strip_label = node_label.strip('"') @@ -1559,7 +1576,7 @@ class PGGraphStorage(BaseGraphStorage): kg = KnowledgeGraph( nodes=list(nodes_dict.values()), edges=list(edges_dict.values()), - is_truncated=False, + is_truncated=is_truncated, ) return kg