diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index d9380094..124240a7 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -31,9 +31,10 @@ def create_graph_routes(rag, api_key: Optional[str] = None): Retrieve a connected subgraph of nodes where the label includes the specified label. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence - 2. Followed by nodes directly connected to the matching nodes - 3. Finally, the degree of the nodes + 1. min_degree does not affect nodes directly connected to the matching nodes + 2. Label matching nodes take precedence + 3. Followed by nodes directly connected to the matching nodes + 4. Finally, the degree of the nodes Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) Args: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 2c76a67d..84518523 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -242,9 +242,11 @@ class NetworkXStorage(BaseGraphStorage): Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence - 2. Followed by nodes directly connected to the matching nodes - 3. Finally, the degree of the nodes + 1. min_degree does not affect nodes directly connected to the matching nodes + 2. Label matching nodes take precedence + 3. Followed by nodes directly connected to the matching nodes + 4. Finally, the degree of the nodes + Args: node_label: Label of the starting node @@ -289,12 +291,25 @@ class NetworkXStorage(BaseGraphStorage): node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) combined_subgraph = nx.compose(combined_subgraph, node_subgraph) - # Filter nodes based on min_degree + # Get start nodes and direct connected nodes + start_nodes = set() + direct_connected_nodes = set() + + if node_label != "*" and nodes_to_explore: + start_nodes = set(nodes_to_explore) + # Get nodes directly connected to all start nodes + for start_node in start_nodes: + direct_connected_nodes.update(combined_subgraph.neighbors(start_node)) + + # Remove start nodes from directly connected nodes (avoid duplicates) + direct_connected_nodes -= start_nodes + + # Filter nodes based on min_degree, but keep start nodes and direct connected nodes if min_degree > 0: nodes_to_keep = [ node for node, degree in combined_subgraph.degree() - if degree >= min_degree + if node in start_nodes or node in direct_connected_nodes or degree >= min_degree ] combined_subgraph = combined_subgraph.subgraph(nodes_to_keep) @@ -303,21 +318,8 @@ class NetworkXStorage(BaseGraphStorage): # Check if number of nodes exceeds max_graph_nodes if len(subgraph.nodes()) > MAX_GRAPH_NODES: origin_nodes = len(subgraph.nodes()) - node_degrees = dict(subgraph.degree()) - start_nodes = set() - direct_connected_nodes = set() - - if node_label != "*" and nodes_to_explore: - start_nodes = set(nodes_to_explore) - # Get nodes directly connected to all start nodes - for start_node in start_nodes: - direct_connected_nodes.update(subgraph.neighbors(start_node)) - - # Remove start nodes from directly connected nodes (avoid duplicates) - direct_connected_nodes -= start_nodes - def priority_key(node_item): node, degree = node_item # Priority order: start(2) > directly connected(1) > other nodes(0)