From 82c4baba70883a91c9790f37ea7e436507bc2cce Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 2 Apr 2025 22:12:20 +0800 Subject: [PATCH] Add is_truncated to graph query for NetworkX graph db --- lightrag/base.py | 13 ++++++++++++- lightrag/kg/neo4j_impl.py | 9 +++------ lightrag/kg/networkx_impl.py | 22 +++++++++++++++++----- lightrag/types.py | 1 + 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index ec7ba9fa..0d387cc3 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -343,7 +343,18 @@ class BaseGraphStorage(StorageNameSpace, ABC): async def get_knowledge_graph( self, node_label: str, max_depth: int = 3, max_nodes: int = 1000 ) -> KnowledgeGraph: - """Retrieve a subgraph of the knowledge graph starting from a given node.""" + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node,* means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + """ class DocStatus(str, Enum): diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 1244d4cb..3c7e57a7 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -651,17 +651,14 @@ class Neo4JStorage(BaseGraphStorage): ) -> KnowledgeGraph: """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. - When reducing the number of nodes, the prioritization criteria are as follows: - 1. Hops(path) to the staring node take precedence - 2. Followed by the degree of the nodes Args: - node_label: Label of the starting node + node_label: Label of the starting node,* means all nodes max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maxiumu nodes to return + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 Returns: - KnowledgeGraph: Complete connected subgraph for specified node + KnowledgeGraph object containing nodes and edges """ result = KnowledgeGraph() seen_nodes = set() diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index f3483afa..c637ff01 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -270,16 +270,24 @@ class NetworkXStorage(BaseGraphStorage): max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 Returns: - KnowledgeGraph object containing nodes and edges + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit """ graph = await self._get_graph() + result = KnowledgeGraph() + # Handle special case for "*" label if node_label == "*": # Get degrees of all nodes degrees = dict(graph.degree()) # Sort nodes by degree in descending order and take top max_nodes sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True) + + # Check if graph is truncated + if len(sorted_nodes) > max_nodes: + result.is_truncated = True + limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]] # Create subgraph with the highest degree nodes subgraph = graph.subgraph(limited_nodes) @@ -293,23 +301,27 @@ class NetworkXStorage(BaseGraphStorage): bfs_nodes = [] visited = set() queue = [node_label] - + # Breadth-first search while queue and len(bfs_nodes) < max_nodes: current = queue.pop(0) if current not in visited: visited.add(current) bfs_nodes.append(current) - + # Add neighbor nodes to queue neighbors = list(graph.neighbors(current)) queue.extend([n for n in neighbors if n not in visited]) - + + # Check if graph is truncated - if we still have nodes in the queue + # and we've reached max_nodes, then the graph is truncated + if queue and len(bfs_nodes) >= max_nodes: + result.is_truncated = True + # Create subgraph with BFS discovered nodes subgraph = graph.subgraph(bfs_nodes) # Add nodes to result - result = KnowledgeGraph() seen_nodes = set() seen_edges = set() for node in subgraph.nodes(): diff --git a/lightrag/types.py b/lightrag/types.py index 5e3d2948..a18f2d3c 100644 --- a/lightrag/types.py +++ b/lightrag/types.py @@ -26,3 +26,4 @@ class KnowledgeGraphEdge(BaseModel): class KnowledgeGraph(BaseModel): nodes: list[KnowledgeGraphNode] = [] edges: list[KnowledgeGraphEdge] = [] + is_truncated: bool = False