Added min_degree exception for connected nodes
This commit is contained in:
@@ -31,9 +31,10 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||||
1. Label matching nodes take precedence
|
1. min_degree filtering is not applied to the matching nodes or to nodes directly connected to them
|
||||||
2. Followed by nodes directly connected to the matching nodes
|
2. Label matching nodes take precedence
|
||||||
3. Finally, the degree of the nodes
|
3. Followed by nodes directly connected to the matching nodes
|
||||||
|
4. Finally, the degree of the nodes
|
||||||
Maximum number of nodes is limited to env MAX_GRAPH_NODES (default: 1000)
|
Maximum number of nodes is limited to env MAX_GRAPH_NODES (default: 1000)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@@ -242,9 +242,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||||
When reducing the number of nodes, the prioritization criteria are as follows:
|
When reducing the number of nodes, the prioritization criteria are as follows:
|
||||||
1. Label matching nodes take precedence
|
1. min_degree filtering is not applied to the matching nodes or to nodes directly connected to them
|
||||||
2. Followed by nodes directly connected to the matching nodes
|
2. Label matching nodes take precedence
|
||||||
3. Finally, the degree of the nodes
|
3. Followed by nodes directly connected to the matching nodes
|
||||||
|
4. Finally, the degree of the nodes
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_label: Label of the starting node
|
node_label: Label of the starting node
|
||||||
@@ -289,12 +291,25 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
||||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||||
|
|
||||||
# Filter nodes based on min_degree
|
# Get start nodes and directly connected nodes
|
||||||
|
start_nodes = set()
|
||||||
|
direct_connected_nodes = set()
|
||||||
|
|
||||||
|
if node_label != "*" and nodes_to_explore:
|
||||||
|
start_nodes = set(nodes_to_explore)
|
||||||
|
# Get nodes directly connected to all start nodes
|
||||||
|
for start_node in start_nodes:
|
||||||
|
direct_connected_nodes.update(combined_subgraph.neighbors(start_node))
|
||||||
|
|
||||||
|
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||||
|
direct_connected_nodes -= start_nodes
|
||||||
|
|
||||||
|
# Filter nodes based on min_degree, but keep start nodes and directly connected nodes
|
||||||
if min_degree > 0:
|
if min_degree > 0:
|
||||||
nodes_to_keep = [
|
nodes_to_keep = [
|
||||||
node
|
node
|
||||||
for node, degree in combined_subgraph.degree()
|
for node, degree in combined_subgraph.degree()
|
||||||
if degree >= min_degree
|
if node in start_nodes or node in direct_connected_nodes or degree >= min_degree
|
||||||
]
|
]
|
||||||
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
||||||
|
|
||||||
@@ -303,21 +318,8 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
# Check if number of nodes exceeds max_graph_nodes
|
# Check if number of nodes exceeds max_graph_nodes
|
||||||
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
||||||
origin_nodes = len(subgraph.nodes())
|
origin_nodes = len(subgraph.nodes())
|
||||||
|
|
||||||
node_degrees = dict(subgraph.degree())
|
node_degrees = dict(subgraph.degree())
|
||||||
|
|
||||||
start_nodes = set()
|
|
||||||
direct_connected_nodes = set()
|
|
||||||
|
|
||||||
if node_label != "*" and nodes_to_explore:
|
|
||||||
start_nodes = set(nodes_to_explore)
|
|
||||||
# Get nodes directly connected to all start nodes
|
|
||||||
for start_node in start_nodes:
|
|
||||||
direct_connected_nodes.update(subgraph.neighbors(start_node))
|
|
||||||
|
|
||||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
|
||||||
direct_connected_nodes -= start_nodes
|
|
||||||
|
|
||||||
def priority_key(node_item):
|
def priority_key(node_item):
|
||||||
node, degree = node_item
|
node, degree = node_item
|
||||||
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
# Priority order: start(2) > directly connected(1) > other nodes(0)
|
||||||
|
Reference in New Issue
Block a user