Fix subgraph filtering bugs

This commit is contained in:
yangdx
2025-03-05 13:13:46 +08:00
parent cb988f20da
commit ca9e958cad

View File

@@ -263,12 +263,14 @@ class NetworkXStorage(BaseGraphStorage):
 graph = await self._get_graph()
+# Initialize sets for start nodes and direct connected nodes
+start_nodes = set()
+direct_connected_nodes = set()
 # Handle special case for "*" label
 if node_label == "*":
 # For "*", return the entire graph including all nodes and edges
-subgraph = (
-graph.copy()
-) # Create a copy to avoid modifying the original graph
+subgraph = graph.copy() # Create a copy to avoid modifying the original graph
 else:
 # Find nodes with matching node id based on search_mode
 nodes_to_explore = []
@@ -292,10 +294,7 @@ class NetworkXStorage(BaseGraphStorage):
 combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
 # Get start nodes and direct connected nodes
-start_nodes = set()
-direct_connected_nodes = set()
-if node_label != "*" and nodes_to_explore:
+if nodes_to_explore:
 start_nodes = set(nodes_to_explore)
 # Get nodes directly connected to all start nodes
 for start_node in start_nodes:
@@ -306,18 +305,17 @@ class NetworkXStorage(BaseGraphStorage):
 # Remove start nodes from directly connected nodes (avoid duplicates)
 direct_connected_nodes -= start_nodes
+subgraph = combined_subgraph
 # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
 if min_degree > 0:
 nodes_to_keep = [
 node
-for node, degree in combined_subgraph.degree()
-if node in start_nodes
-or node in direct_connected_nodes
+for node, degree in subgraph.degree()
+if (node_label != "*" and (node in start_nodes or node in direct_connected_nodes))
 or degree >= min_degree
 ]
-combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
-subgraph = combined_subgraph
+subgraph = subgraph.subgraph(nodes_to_keep)
 # Check if number of nodes exceeds max_graph_nodes
 if len(subgraph.nodes()) > MAX_GRAPH_NODES: