Added min_degree exception for connected nodes
This commit is contained in:
@@ -31,9 +31,10 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. Label matching nodes take precedence
|
||||
2. Followed by nodes directly connected to the matching nodes
|
||||
3. Finally, the degree of the nodes
|
||||
1. min_degree does not affect the label-matching nodes or the nodes directly connected to them
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
Maximum number of nodes is limited by the environment variable `MAX_GRAPH_NODES` (default: 1000)
|
||||
|
||||
Args:
|
||||
|
@@ -242,9 +242,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||
1. Label matching nodes take precedence
|
||||
2. Followed by nodes directly connected to the matching nodes
|
||||
3. Finally, the degree of the nodes
|
||||
1. min_degree does not affect the label-matching nodes or the nodes directly connected to them
|
||||
2. Label matching nodes take precedence
|
||||
3. Followed by nodes directly connected to the matching nodes
|
||||
4. Finally, the degree of the nodes
|
||||
|
||||
|
||||
Args:
|
||||
node_label: Label of the starting node
|
||||
@@ -289,12 +291,25 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||
|
||||
# Filter nodes based on min_degree
|
||||
# Get start nodes and directly connected nodes
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
|
||||
if node_label != "*" and nodes_to_explore:
|
||||
start_nodes = set(nodes_to_explore)
|
||||
# Get nodes directly connected to all start nodes
|
||||
for start_node in start_nodes:
|
||||
direct_connected_nodes.update(combined_subgraph.neighbors(start_node))
|
||||
|
||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||
direct_connected_nodes -= start_nodes
|
||||
|
||||
# Filter nodes based on min_degree, but keep start nodes and directly connected nodes
|
||||
if min_degree > 0:
|
||||
nodes_to_keep = [
|
||||
node
|
||||
for node, degree in combined_subgraph.degree()
|
||||
if degree >= min_degree
|
||||
if node in start_nodes or node in direct_connected_nodes or degree >= min_degree
|
||||
]
|
||||
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
||||
|
||||
@@ -303,21 +318,8 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
||||
origin_nodes = len(subgraph.nodes())
|
||||
|
||||
node_degrees = dict(subgraph.degree())
|
||||
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
|
||||
if node_label != "*" and nodes_to_explore:
|
||||
start_nodes = set(nodes_to_explore)
|
||||
# Get nodes directly connected to all start nodes
|
||||
for start_node in start_nodes:
|
||||
direct_connected_nodes.update(subgraph.neighbors(start_node))
|
||||
|
||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||
direct_connected_nodes -= start_nodes
|
||||
|
||||
def priority_key(node_item):
|
||||
node, degree = node_item
|
||||
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
||||
|
Reference in New Issue
Block a user