diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index fb11b31c..213eced4 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1824,6 +1824,159 @@ class PGGraphStorage(BaseGraphStorage): labels.append(result["label"]) return labels + async def _bfs_subgraph( + self, node_label: str, max_depth: int, max_nodes: int + ) -> KnowledgeGraph: + """ + Implements a true breadth-first search algorithm for subgraph retrieval. + This method is used as a fallback when the standard Cypher query is too slow + or when we need to guarantee BFS ordering. + + Args: + node_label: Label of the starting node + max_depth: Maximum depth of the subgraph + max_nodes: Maximum number of nodes to return + + Returns: + KnowledgeGraph object containing nodes and edges + """ + from collections import deque + + result = KnowledgeGraph() + visited_nodes = set() + visited_node_ids = set() + visited_edges = set() + visited_edge_pairs = set() + + # Get starting node data + label = self._normalize_node_id(node_label) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"}) + RETURN id(n) as node_id, n + $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label) + + node_result = await self._query(query) + if not node_result or not node_result[0].get("n"): + return result + + # Create initial KnowledgeGraphNode + start_node_data = node_result[0]["n"] + entity_id = start_node_data["properties"]["entity_id"] + internal_id = str(start_node_data["id"]) + + start_node = KnowledgeGraphNode( + id=internal_id, + labels=[entity_id], + properties=start_node_data["properties"], + ) + + # Initialize BFS queue, each element is a tuple of (node, depth) + queue = deque([(start_node, 0)]) + + visited_nodes.add(entity_id) + visited_node_ids.add(internal_id) + result.nodes.append(start_node) + + while queue and len(visited_node_ids) < max_nodes: + # Dequeue the next node to process from the front of the queue + current_node, current_depth = queue.popleft() + + if current_depth >= max_depth: + continue + + # Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency + current_entity_id = current_node.labels[0] + outgoing_query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {entity_id: "%s"})-[r]->(b) + WITH r, b, id(r) as edge_id, id(b) as target_id + RETURN r, b, edge_id, target_id + $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % ( + self.graph_name, + current_entity_id, + ) + incoming_query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {entity_id: "%s"})<-[r]-(b) + WITH r, b, id(r) as edge_id, id(b) as target_id + RETURN r, b, edge_id, target_id + $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % ( + self.graph_name, + current_entity_id, + ) + + outgoing_neighbors = await self._query(outgoing_query) + incoming_neighbors = await self._query(incoming_query) + neighbors = outgoing_neighbors + incoming_neighbors + + # logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})") + + for record in neighbors: + if not record.get("b") or not record.get("r"): + continue + + b_node = record["b"] + rel = record["r"] + edge_id = str(record["edge_id"]) + + if ( + "properties" not in b_node + or "entity_id" not in b_node["properties"] + ): + continue + + target_entity_id = b_node["properties"]["entity_id"] + target_internal_id = str(b_node["id"]) + + # Create KnowledgeGraphNode for target + target_node = KnowledgeGraphNode( + id=target_internal_id, + labels=[target_entity_id], + properties=b_node["properties"], + ) + + # Create edge object + edge = KnowledgeGraphEdge( + id=edge_id, + type=rel["label"], + source=current_node.id, + target=target_internal_id, + properties=rel["properties"], + ) + + # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge + sorted_pair = tuple(sorted([current_entity_id, target_entity_id])) + + # Add edge (if not already added) + if ( + edge_id not in visited_edges + and sorted_pair not in visited_edge_pairs + ): + result.edges.append(edge) + visited_edges.add(edge_id) + visited_edge_pairs.add(sorted_pair) + + # If target node not yet visited, add to result and queue + if target_internal_id not in visited_node_ids: + result.nodes.append(target_node) + visited_nodes.add(target_entity_id) + visited_node_ids.add(target_internal_id) + + # Add node to queue with incremented depth + queue.append((target_node, current_depth + 1)) + + # If node limit reached, set truncated flag and exit + if len(visited_node_ids) >= max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: BFS limited to {max_nodes} nodes" + ) + break + + # If inner loop reached node limit and exited, also exit outer loop + if len(visited_node_ids) >= max_nodes: + break + + return result + async def get_knowledge_graph( self, node_label: str, @@ -1836,110 +1989,71 @@ class PGGraphStorage(BaseGraphStorage): Args: node_label: Label of the starting node, * means all nodes max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed) + max_nodes: Maxiumu nodes to return, Defaults to 1000 Returns: KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ - # First, count the total number of nodes that would be returned without limit + # Handle wildcard query - get all nodes if node_label == "*": + # First check total node count to determine if graph should be truncated count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ MATCH (n:base) RETURN count(distinct n) AS total_nodes $$) AS (total_nodes bigint)""" - else: - strip_label = self._normalize_node_id(node_label) - count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base {{entity_id: "{strip_label}"}})-[r]-() - RETURN count(r) AS total_nodes - $$) AS (total_nodes bigint)""" - count_result = await self._query(count_query) - total_nodes = count_result[0]["total_nodes"] if count_result else 0 - is_truncated = total_nodes > max_nodes + count_result = await self._query(count_query) + total_nodes = count_result[0]["total_nodes"] if count_result else 0 + is_truncated = total_nodes > max_nodes - # Now get the actual data with limit - if node_label == "*": + # Get nodes and edges query = f"""SELECT * FROM cypher('{self.graph_name}', $$ MATCH (node:base) OPTIONAL MATCH (node)-[r]->() RETURN collect(distinct node) AS n, collect(distinct r) AS r LIMIT {max_nodes} $$) AS (n agtype, r agtype)""" + + results = await self._query(query) + + # Process query results, deduplicate nodes and edges + nodes_dict = {} + edges_dict = {} + + for result in results: + if result.get("n") and isinstance(result["n"], list): + for node in result["n"]: + if isinstance(node, dict) and "id" in node: + node_id = str(node["id"]) + if node_id not in nodes_dict and "properties" in node: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node["properties"]["entity_id"]], + properties=node["properties"], + ) + + if result.get("r") and isinstance(result["r"], list): + for edge in result["r"]: + if isinstance(edge, dict) and "id" in edge: + edge_id = str(edge["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(edge["start_id"]), + target=str(edge["end_id"]), + properties=edge["properties"], + ) + + kg = KnowledgeGraph( + nodes=list(nodes_dict.values()), + edges=list(edges_dict.values()), + is_truncated=is_truncated, + ) else: - strip_label = self._normalize_node_id(node_label) - if total_nodes > 0: - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (node:base {{entity_id: "{strip_label}"}}) - OPTIONAL MATCH p = (node)-[*..{max_depth}]-() - RETURN nodes(p) AS n, relationships(p) AS r - LIMIT {max_nodes} - $$) AS (n agtype, r agtype)""" - else: - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (node:base {{entity_id: "{strip_label}"}}) - RETURN node AS n - $$) AS (n agtype)""" - - results = await self._query(query) - - # Process the query results with deduplication by node and edge IDs - nodes_dict = {} - edges_dict = {} - for result in results: - # Handle single node cases - if result.get("n") and isinstance(result["n"], dict): - node_id = str(result["n"]["id"]) - if node_id not in nodes_dict: - nodes_dict[node_id] = KnowledgeGraphNode( - id=node_id, - labels=[result["n"]["properties"]["entity_id"]], - properties=result["n"]["properties"], - ) - # Handle node list cases - elif result.get("n") and isinstance(result["n"], list): - for node in result["n"]: - if isinstance(node, dict) and "id" in node: - node_id = str(node["id"]) - if node_id not in nodes_dict and "properties" in node: - nodes_dict[node_id] = KnowledgeGraphNode( - id=node_id, - labels=[node["properties"]["entity_id"]], - properties=node["properties"], - ) - - # Handle single edge cases - if result.get("r") and isinstance(result["r"], dict): - edge_id = str(result["r"]["id"]) - if edge_id not in edges_dict: - edges_dict[edge_id] = KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(result["r"]["start_id"]), - target=str(result["r"]["end_id"]), - properties=result["r"]["properties"], - ) - # Handle edge list cases - elif result.get("r") and isinstance(result["r"], list): - for edge in result["r"]: - if isinstance(edge, dict) and "id" in edge: - edge_id = str(edge["id"]) - if edge_id not in edges_dict: - edges_dict[edge_id] = KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(edge["start_id"]), - target=str(edge["end_id"]), - properties=edge["properties"], - ) - - # Construct and return the KnowledgeGraph with deduplicated nodes and edges - kg = KnowledgeGraph( - nodes=list(nodes_dict.values()), - edges=list(edges_dict.values()), - is_truncated=is_truncated, - ) + # For single node query, use BFS algorithm + kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) logger.info( f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"