Fix subgraph filtering bugs
This commit is contained in:
@@ -263,12 +263,14 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
graph = await self._get_graph()
|
||||
|
||||
# Initialize sets for start nodes and direct connected nodes
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# For "*", return the entire graph including all nodes and edges
|
||||
subgraph = (
|
||||
graph.copy()
|
||||
) # Create a copy to avoid modifying the original graph
|
||||
subgraph = graph.copy() # Create a copy to avoid modifying the original graph
|
||||
else:
|
||||
# Find nodes with matching node id based on search_mode
|
||||
nodes_to_explore = []
|
||||
@@ -292,10 +294,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||
|
||||
# Get start nodes and direct connected nodes
|
||||
start_nodes = set()
|
||||
direct_connected_nodes = set()
|
||||
|
||||
if node_label != "*" and nodes_to_explore:
|
||||
if nodes_to_explore:
|
||||
start_nodes = set(nodes_to_explore)
|
||||
# Get nodes directly connected to all start nodes
|
||||
for start_node in start_nodes:
|
||||
@@ -306,18 +305,17 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
# Remove start nodes from directly connected nodes (avoid duplicates)
|
||||
direct_connected_nodes -= start_nodes
|
||||
|
||||
subgraph = combined_subgraph
|
||||
|
||||
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
|
||||
if min_degree > 0:
|
||||
nodes_to_keep = [
|
||||
node
|
||||
for node, degree in combined_subgraph.degree()
|
||||
if node in start_nodes
|
||||
or node in direct_connected_nodes
|
||||
for node, degree in subgraph.degree()
|
||||
if (node_label != "*" and (node in start_nodes or node in direct_connected_nodes))
|
||||
or degree >= min_degree
|
||||
]
|
||||
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
||||
|
||||
subgraph = combined_subgraph
|
||||
subgraph = subgraph.subgraph(nodes_to_keep)
|
||||
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
|
||||
|
Reference in New Issue
Block a user