diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index bad13174..99a3dc57 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage): result.is_truncated = False + # BFS search main loop while queue: - # Dequeue the next node to process from the front of the queue - current_node, current_depth = queue.popleft() - - # Check one more depth for backward edges + # Get all nodes at the current depth + current_level_nodes = [] + current_depth = None + + # Determine current depth + if queue: + current_depth = queue[0][1] + + # Extract all nodes at current depth from the queue + while queue and queue[0][1] == current_depth: + node, depth = queue.popleft() + if depth > max_depth: + continue + current_level_nodes.append(node) + + if not current_level_nodes: + continue + + # Check depth limit if current_depth > max_depth: continue - - # Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency - current_entity_id = current_node.labels[0] - outgoing_query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]->(b) - WITH r, b, id(r) as edge_id, id(b) as target_id - RETURN r, b, edge_id, target_id - $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % ( - self.graph_name, - current_entity_id, - ) - incoming_query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})<-[r]-(b) - WITH r, b, id(r) as edge_id, id(b) as target_id - RETURN r, b, edge_id, target_id - $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % ( - self.graph_name, - current_entity_id, - ) - - outgoing_neighbors = await self._query(outgoing_query) - incoming_neighbors = await self._query(incoming_query) - neighbors = outgoing_neighbors + incoming_neighbors - - # logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})") - + + # Prepare node IDs list + node_ids = [node.labels[0] for node in current_level_nodes] + formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]) + + # Construct batch query for outgoing edges + outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND [{formatted_ids}] AS node_id + MATCH (n:base {{entity_id: node_id}}) + OPTIONAL MATCH (n)-[r]->(neighbor:base) + RETURN node_id AS current_id, + id(n) AS current_internal_id, + id(neighbor) AS neighbor_internal_id, + neighbor.entity_id AS neighbor_id, + id(r) AS edge_id, + r, + neighbor, + true AS is_outgoing + $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" + + # Construct batch query for incoming edges + incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND [{formatted_ids}] AS node_id + MATCH (n:base {{entity_id: node_id}}) + OPTIONAL MATCH (n)<-[r]-(neighbor:base) + RETURN node_id AS current_id, + id(n) AS current_internal_id, + id(neighbor) AS neighbor_internal_id, + neighbor.entity_id AS neighbor_id, + id(r) AS edge_id, + r, + neighbor, + false AS is_outgoing + $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" + + # Execute queries + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) + + # Combine results + neighbors = outgoing_results + incoming_results + + # Create mapping from node ID to node object + node_map = {node.labels[0]: node for node in current_level_nodes} + + # Process all results in a single loop for record in neighbors: - if not record.get("b") or not record.get("r"): + if not record.get("neighbor") or not record.get("r"): continue - - b_node = record["b"] + + # Get current node information + current_entity_id = record["current_id"] + current_node = node_map[current_entity_id] + + # Get neighbor node information + neighbor_entity_id = record["neighbor_id"] + neighbor_internal_id = str(record["neighbor_internal_id"]) + is_outgoing = record["is_outgoing"] + + # Determine edge direction + if is_outgoing: + source_id = current_node.id + target_id = neighbor_internal_id + else: + source_id = neighbor_internal_id + target_id = current_node.id + + if not neighbor_entity_id: + continue + + # Get edge and node information + b_node = record["neighbor"] rel = record["r"] edge_id = str(record["edge_id"]) - - if ( - "properties" not in b_node - or "entity_id" not in b_node["properties"] - ): - continue - - target_entity_id = b_node["properties"]["entity_id"] - target_internal_id = str(b_node["id"]) - - # Create KnowledgeGraphNode for target - target_node = KnowledgeGraphNode( - id=target_internal_id, - labels=[target_entity_id], + + # Create neighbor node object + neighbor_node = KnowledgeGraphNode( + id=neighbor_internal_id, + labels=[neighbor_entity_id], properties=b_node["properties"], ) - + # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge - sorted_pair = tuple(sorted([current_entity_id, target_entity_id])) - + sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id])) + # Create edge object edge = KnowledgeGraphEdge( id=edge_id, type=rel["label"], - source=current_node.id, - target=target_internal_id, + source=source_id, + target=target_id, properties=rel["properties"], ) - - if target_internal_id in visited_node_ids: - # Add backward edge if target node is visited + + if neighbor_internal_id in visited_node_ids: + # Add backward edge if neighbor node is already visited if ( edge_id not in visited_edges and sorted_pair not in visited_edge_pairs @@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage): result.edges.append(edge) visited_edges.add(edge_id) visited_edge_pairs.add(sorted_pair) - else: if len(visited_node_ids) < max_nodes and current_depth < max_depth: - # If target node not yet visited, add to result and queue - result.nodes.append(target_node) - visited_nodes.add(target_entity_id) - visited_node_ids.add(target_internal_id) - + # Add new node to result and queue + result.nodes.append(neighbor_node) + visited_nodes.add(neighbor_entity_id) + visited_node_ids.add(neighbor_internal_id) + # Add node to queue with incremented depth - queue.append((target_node, current_depth + 1)) - + queue.append((neighbor_node, current_depth + 1)) + # Add forward edge if ( edge_id not in visited_edges @@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage): result.edges.append(edge) visited_edges.add(edge_id) visited_edge_pairs.add(sorted_pair) - # logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}") else: if current_depth < max_depth: result.is_truncated = True