From ca9e958cad41d711e742acfe95426d2ef5c93382 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 5 Mar 2025 13:13:46 +0800 Subject: [PATCH] Fix subgraph filtering bugs --- lightrag/kg/networkx_impl.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 8e84dd6f..60122166 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -263,12 +263,14 @@ class NetworkXStorage(BaseGraphStorage): graph = await self._get_graph() + # Initialize sets for start nodes and direct connected nodes + start_nodes = set() + direct_connected_nodes = set() + # Handle special case for "*" label if node_label == "*": # For "*", return the entire graph including all nodes and edges - subgraph = ( - graph.copy() - ) # Create a copy to avoid modifying the original graph + subgraph = graph.copy() # Create a copy to avoid modifying the original graph else: # Find nodes with matching node id based on search_mode nodes_to_explore = [] @@ -292,10 +294,7 @@ class NetworkXStorage(BaseGraphStorage): combined_subgraph = nx.compose(combined_subgraph, node_subgraph) # Get start nodes and direct connected nodes - start_nodes = set() - direct_connected_nodes = set() - - if node_label != "*" and nodes_to_explore: + if nodes_to_explore: start_nodes = set(nodes_to_explore) # Get nodes directly connected to all start nodes for start_node in start_nodes: @@ -306,19 +305,18 @@ class NetworkXStorage(BaseGraphStorage): # Remove start nodes from directly connected nodes (avoid duplicates) direct_connected_nodes -= start_nodes - # Filter nodes based on min_degree, but keep start nodes and direct connected nodes - if min_degree > 0: - nodes_to_keep = [ - node - for node, degree in combined_subgraph.degree() - if node in start_nodes - or node in direct_connected_nodes - or degree >= min_degree - ] - combined_subgraph = combined_subgraph.subgraph(nodes_to_keep) - subgraph = combined_subgraph + # Filter nodes based on min_degree, but keep start nodes and direct connected nodes + if min_degree > 0: + nodes_to_keep = [ + node + for node, degree in subgraph.degree() + if (node_label != "*" and (node in start_nodes or node in direct_connected_nodes)) + or degree >= min_degree + ] + subgraph = subgraph.subgraph(nodes_to_keep) + # Check if number of nodes exceeds max_graph_nodes if len(subgraph.nodes()) > MAX_GRAPH_NODES: origin_nodes = len(subgraph.nodes())