Exempt label-matching nodes and their direct neighbors from min_degree filtering

This commit is contained in:
yangdx
2025-03-05 11:48:04 +08:00
parent 1fddc8552e
commit 5e40e4107d
2 changed files with 24 additions and 21 deletions

View File

@@ -31,9 +31,10 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Retrieve a connected subgraph of nodes where the label includes the specified label. Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows: When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence 1. min_degree does not affect nodes directly connected to the matching nodes
2. Followed by nodes directly connected to the matching nodes 2. Label matching nodes take precedence
3. Finally, the degree of the nodes 3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES (default: 1000) Maximum number of nodes is limited to env MAX_GRAPH_NODES (default: 1000)
Args: Args:

View File

@@ -242,9 +242,11 @@ class NetworkXStorage(BaseGraphStorage):
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows: When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence 1. min_degree does not affect nodes directly connected to the matching nodes
2. Followed by nodes directly connected to the matching nodes 2. Label matching nodes take precedence
3. Finally, the degree of the nodes 3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node
@@ -289,12 +291,25 @@ class NetworkXStorage(BaseGraphStorage):
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph) combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Filter nodes based on min_degree # Get start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(combined_subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0: if min_degree > 0:
nodes_to_keep = [ nodes_to_keep = [
node node
for node, degree in combined_subgraph.degree() for node, degree in combined_subgraph.degree()
if degree >= min_degree if node in start_nodes or node in direct_connected_nodes or degree >= min_degree
] ]
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep) combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
@@ -303,21 +318,8 @@ class NetworkXStorage(BaseGraphStorage):
# Check if number of nodes exceeds max_graph_nodes # Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES: if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes()) origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree()) node_degrees = dict(subgraph.degree())
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
def priority_key(node_item): def priority_key(node_item):
node, degree = node_item node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0) # Priority order: start(2) > directly connected(1) > other nodes(0)