diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 6c247fc3..bad13174 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1879,11 +1879,14 @@ class PGGraphStorage(BaseGraphStorage): visited_node_ids.add(internal_id) result.nodes.append(start_node) - while queue and len(visited_node_ids) < max_nodes: + result.is_truncated = False + + while queue: # Dequeue the next node to process from the front of the queue current_node, current_depth = queue.popleft() - if current_depth >= max_depth: + # Check one more depth for backward edges + if current_depth > max_depth: continue # Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency @@ -1935,6 +1938,9 @@ class PGGraphStorage(BaseGraphStorage): properties=b_node["properties"], ) + # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge + sorted_pair = tuple(sorted([current_entity_id, target_entity_id])) + # Create edge object edge = KnowledgeGraphEdge( id=edge_id, @@ -1944,38 +1950,38 @@ class PGGraphStorage(BaseGraphStorage): properties=rel["properties"], ) - # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge - sorted_pair = tuple(sorted([current_entity_id, target_entity_id])) + if target_internal_id in visited_node_ids: + # Add backward edge if target node is visited + if ( + edge_id not in visited_edges + and sorted_pair not in visited_edge_pairs + ): + result.edges.append(edge) + visited_edges.add(edge_id) + visited_edge_pairs.add(sorted_pair) - # Add edge (if not already added) - if ( - edge_id not in visited_edges - and sorted_pair not in visited_edge_pairs - ): - result.edges.append(edge) - visited_edges.add(edge_id) - visited_edge_pairs.add(sorted_pair) + else: + if len(visited_node_ids) < max_nodes and current_depth < max_depth: + # If target node not yet visited, add to result and queue + result.nodes.append(target_node) + visited_nodes.add(target_entity_id) + visited_node_ids.add(target_internal_id) - # If target node not yet visited, add to result and queue - if target_internal_id not in visited_node_ids: - result.nodes.append(target_node) - visited_nodes.add(target_entity_id) - visited_node_ids.add(target_internal_id) + # Add node to queue with incremented depth + queue.append((target_node, current_depth + 1)) - # Add node to queue with incremented depth - queue.append((target_node, current_depth + 1)) - - # If node limit reached, set truncated flag and exit - if len(visited_node_ids) >= max_nodes: - result.is_truncated = True - logger.info( - f"Graph truncated: BFS limited to {max_nodes} nodes" - ) - break - - # If inner loop reached node limit and exited, also exit outer loop - if len(visited_node_ids) >= max_nodes: - break + # Add forward edge + if ( + edge_id not in visited_edges + and sorted_pair not in visited_edge_pairs + ): + result.edges.append(edge) + visited_edges.add(edge_id) + visited_edge_pairs.add(sorted_pair) + # logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}") + else: + if current_depth < max_depth: + result.is_truncated = True return result