Optimize NetworkX subgraph query

This commit is contained in:
yangdx
2025-04-02 21:41:24 +08:00
parent a50edffdb0
commit 4ceafb7cbc

View File

@@ -259,118 +259,59 @@ class NetworkXStorage(BaseGraphStorage):
self, self,
node_label: str, node_label: str,
max_depth: int = 3, max_depth: int = 3,
min_degree: int = 0, max_nodes: int = MAX_GRAPH_NODES,
inclusive: bool = False,
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. min_degree does not affect nodes directly connected to the matching nodes
2. Label matching nodes take precedence
3. Followed by nodes directly connected to the matching nodes
4. Finally, the degree of the nodes
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node* means all nodes
max_depth: Maximum depth of the subgraph max_depth: Maximum depth of the subgraph, Defaults to 3
min_degree: Minimum degree of nodes to include. Defaults to 0 max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
inclusive: Do an inclusive search if true
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges
""" """
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
graph = await self._get_graph() graph = await self._get_graph()
# Initialize sets for start nodes and direct connected nodes
start_nodes = set()
direct_connected_nodes = set()
# Handle special case for "*" label # Handle special case for "*" label
if node_label == "*": if node_label == "*":
# For "*", return the entire graph including all nodes and edges # Get degrees of all nodes
subgraph = ( degrees = dict(graph.degree())
graph.copy() # Sort nodes by degree in descending order and take top max_nodes
) # Create a copy to avoid modifying the original graph sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
# Create subgraph with the highest degree nodes
subgraph = graph.subgraph(limited_nodes)
else: else:
# Find nodes with matching node id based on search_mode # Check if node exists
nodes_to_explore = [] if node_label not in graph:
for n, attr in graph.nodes(data=True): logger.warning(f"Node {node_label} not found in the graph")
node_str = str(n) return KnowledgeGraph() # Return empty graph
if not inclusive:
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
if not nodes_to_explore: # Use BFS to get nodes
logger.warning(f"No nodes found with label {node_label}") bfs_nodes = []
return result visited = set()
queue = [node_label]
# Get subgraph using ego_graph from all matching nodes # Breadth-first search
combined_subgraph = nx.Graph() while queue and len(bfs_nodes) < max_nodes:
for start_node in nodes_to_explore: current = queue.pop(0)
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) if current not in visited:
combined_subgraph = nx.compose(combined_subgraph, node_subgraph) visited.add(current)
bfs_nodes.append(current)
# Get start nodes and direct connected nodes # Add neighbor nodes to queue
if nodes_to_explore: neighbors = list(graph.neighbors(current))
start_nodes = set(nodes_to_explore) queue.extend([n for n in neighbors if n not in visited])
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(
combined_subgraph.neighbors(start_node)
)
# Remove start nodes from directly connected nodes (avoid duplicates) # Create subgraph with BFS discovered nodes
direct_connected_nodes -= start_nodes subgraph = graph.subgraph(bfs_nodes)
subgraph = combined_subgraph
# Filter nodes based on min_degree, but keep start nodes and direct connected nodes
if min_degree > 0:
nodes_to_keep = [
node
for node, degree in subgraph.degree()
if node in start_nodes
or node in direct_connected_nodes
or degree >= min_degree
]
subgraph = subgraph.subgraph(nodes_to_keep)
# Check if number of nodes exceeds max_graph_nodes
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree)
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Add nodes to result # Add nodes to result
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
for node in subgraph.nodes(): for node in subgraph.nodes():
if str(node) in seen_nodes: if str(node) in seen_nodes:
continue continue
@@ -398,7 +339,7 @@ class NetworkXStorage(BaseGraphStorage):
for edge in subgraph.edges(): for edge in subgraph.edges():
source, target = edge source, target = edge
# Esure unique edge_id for undirect graph # Esure unique edge_id for undirect graph
if source > target: if str(source) > str(target):
source, target = target, source source, target = target, source
edge_id = f"{source}-{target}" edge_id = f"{source}-{target}"
if edge_id in seen_edges: if edge_id in seen_edges: