Optimize NetworkX subgraph query

This commit is contained in:
yangdx
2025-04-02 21:41:24 +08:00
parent a50edffdb0
commit 4ceafb7cbc

View File

@@ -259,118 +259,59 @@ class NetworkXStorage(BaseGraphStorage):
self,
node_label: str,
max_depth: int = 3,
min_degree: int = 0,
inclusive: bool = False,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
min_degree: Minimum degree of nodes to include. Defaults to 0
inclusive: Do an inclusive search if true
node_label: Label of the starting node* means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
subgraph = (
graph.copy()
) # Create a copy to avoid modifying the original graph
# Get degrees of all nodes
degrees = dict(graph.degree())
# Sort nodes by degree in descending order and take top max_nodes
sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
# Create subgraph with the highest degree nodes
subgraph = graph.subgraph(limited_nodes)
else:
# Find nodes with matching node id based on search_mode
nodes_to_explore = []
for n, attr in graph.nodes(data=True):
node_str = str(n)
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
# Check if node exists
if node_label not in graph:
logger.warning(f"Node {node_label} not found in the graph")
return KnowledgeGraph() # Return empty graph
if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}")
return result
# Use BFS to get nodes
bfs_nodes = []
visited = set()
queue = [node_label]
# Get subgraph using ego_graph from all matching nodes
combined_subgraph = nx.Graph()
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Breadth-first search
while queue and len(bfs_nodes) < max_nodes:
current = queue.pop(0)
if current not in visited:
visited.add(current)
bfs_nodes.append(current)
# Get start nodes and direct connected nodes
if nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Add neighbor nodes to queue
neighbors = list(graph.neighbors(current))
queue.extend([n for n in neighbors if n not in visited])
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree)
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Create subgraph with BFS discovered nodes
subgraph = graph.subgraph(bfs_nodes)
# Add nodes to result
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
for node in subgraph.nodes():
if str(node) in seen_nodes:
continue
@@ -398,7 +339,7 @@ class NetworkXStorage(BaseGraphStorage):
for edge in subgraph.edges():
source, target = edge
# Esure unique edge_id for undirect graph
if source > target:
if str(source) > str(target):
source, target = target, source
edge_id = f"{source}-{target}"
if edge_id in seen_edges: