diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 69cff169..6cc6852f 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -303,18 +303,20 @@ class NetworkXStorage(BaseGraphStorage): # Use BFS to get nodes bfs_nodes = [] visited = set() - queue = [node_label] + queue = [(node_label, 0)] # (node, depth) tuple # Breadth-first search while queue and len(bfs_nodes) < max_nodes: - current = queue.pop(0) + current, depth = queue.pop(0) if current not in visited: visited.add(current) bfs_nodes.append(current) - # Add neighbor nodes to queue - neighbors = list(graph.neighbors(current)) - queue.extend([n for n in neighbors if n not in visited]) + # Only explore neighbors if we haven't reached max_depth + if depth < max_depth: + # Add neighbor nodes to queue with incremented depth + neighbors = list(graph.neighbors(current)) + queue.extend([(n, depth + 1) for n in neighbors if n not in visited]) # Check if graph is truncated - if we still have nodes in the queue # and we've reached max_nodes, then the graph is truncated