Fix subgraph filtering bugs
This commit is contained in:
@@ -263,12 +263,14 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
graph = await self._get_graph()
|
graph = await self._get_graph()
|
||||||
|
|
||||||
|
# Initialize sets for start nodes and direct connected nodes
|
||||||
|
start_nodes = set()
|
||||||
|
direct_connected_nodes = set()
|
||||||
|
|
||||||
# Handle special case for "*" label
|
# Handle special case for "*" label
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
# For "*", return the entire graph including all nodes and edges
|
# For "*", return the entire graph including all nodes and edges
|
||||||
subgraph = (
|
subgraph = graph.copy() # Create a copy to avoid modifying the original graph
|
||||||
graph.copy()
|
|
||||||
) # Create a copy to avoid modifying the original graph
|
|
||||||
else:
|
else:
|
||||||
# Find nodes with matching node id based on search_mode
|
# Find nodes with matching node id based on search_mode
|
||||||
nodes_to_explore = []
|
nodes_to_explore = []
|
||||||
@@ -292,10 +294,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||||
|
|
||||||
# Get start nodes and direct connected nodes
|
# Get start nodes and direct connected nodes
|
||||||
start_nodes = set()
|
if nodes_to_explore:
|
||||||
direct_connected_nodes = set()
|
|
||||||
|
|
||||||
if node_label != "*" and nodes_to_explore:
|
|
||||||
start_nodes = set(nodes_to_explore)
|
start_nodes = set(nodes_to_explore)
|
||||||
# Get nodes directly connected to all start nodes
|
# Get nodes directly connected to all start nodes
|
||||||
for start_node in start_nodes:
|
for start_node in start_nodes:
|
||||||
@@ -306,18 +305,17 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||||
direct_connected_nodes -= start_nodes
|
direct_connected_nodes -= start_nodes
|
||||||
|
|
||||||
|
subgraph = combined_subgraph
|
||||||
|
|
||||||
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
||||||
if min_degree > 0:
|
if min_degree > 0:
|
||||||
nodes_to_keep = [
|
nodes_to_keep = [
|
||||||
node
|
node
|
||||||
for node, degree in combined_subgraph.degree()
|
for node, degree in subgraph.degree()
|
||||||
if node in start_nodes
|
if (node_label != "*" and (node in start_nodes or node in direct_connected_nodes))
|
||||||
or node in direct_connected_nodes
|
|
||||||
or degree >= min_degree
|
or degree >= min_degree
|
||||||
]
|
]
|
||||||
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
subgraph = subgraph.subgraph(nodes_to_keep)
|
||||||
|
|
||||||
subgraph = combined_subgraph
|
|
||||||
|
|
||||||
# Check if number of nodes exceeds max_graph_nodes
|
# Check if number of nodes exceeds max_graph_nodes
|
||||||
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
||||||
|
Reference in New Issue
Block a user