From 5eb019a7fc932bfda80bc1f1448e947c5771cfaf Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 25 Apr 2025 21:25:37 +0800 Subject: [PATCH] Fix max_nodes not working in graph queries when using the '*' wildcard. --- lightrag/kg/postgres_impl.py | 197 +++++++++++++++++++++-------------- 1 file changed, 118 insertions(+), 79 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 99a3dc57..1d991040 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1886,85 +1886,87 @@ class PGGraphStorage(BaseGraphStorage): # Get all nodes at the current depth current_level_nodes = [] current_depth = None - + # Determine current depth if queue: current_depth = queue[0][1] - + # Extract all nodes at current depth from the queue while queue and queue[0][1] == current_depth: node, depth = queue.popleft() if depth > max_depth: continue current_level_nodes.append(node) - + if not current_level_nodes: continue - + # Check depth limit if current_depth > max_depth: continue - + # Prepare node IDs list node_ids = [node.labels[0] for node in current_level_nodes] - formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]) - + formatted_ids = ", ".join( + [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids] + ) + # Construct batch query for outgoing edges outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ UNWIND [{formatted_ids}] AS node_id MATCH (n:base {{entity_id: node_id}}) OPTIONAL MATCH (n)-[r]->(neighbor:base) - RETURN node_id AS current_id, - id(n) AS current_internal_id, - id(neighbor) AS neighbor_internal_id, - neighbor.entity_id AS neighbor_id, - id(r) AS edge_id, - r, + RETURN node_id AS current_id, + id(n) AS current_internal_id, + id(neighbor) AS neighbor_internal_id, + neighbor.entity_id AS neighbor_id, + id(r) AS edge_id, + r, neighbor, true AS is_outgoing - $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" - + # Construct batch query for incoming edges incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ UNWIND [{formatted_ids}] AS node_id MATCH (n:base {{entity_id: node_id}}) OPTIONAL MATCH (n)<-[r]-(neighbor:base) - RETURN node_id AS current_id, - id(n) AS current_internal_id, - id(neighbor) AS neighbor_internal_id, - neighbor.entity_id AS neighbor_id, - id(r) AS edge_id, - r, + RETURN node_id AS current_id, + id(n) AS current_internal_id, + id(neighbor) AS neighbor_internal_id, + neighbor.entity_id AS neighbor_id, + id(r) AS edge_id, + r, neighbor, false AS is_outgoing - $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" - + # Execute queries outgoing_results = await self._query(outgoing_query) incoming_results = await self._query(incoming_query) - + # Combine results neighbors = outgoing_results + incoming_results - + # Create mapping from node ID to node object node_map = {node.labels[0]: node for node in current_level_nodes} - + # Process all results in a single loop for record in neighbors: if not record.get("neighbor") or not record.get("r"): continue - + # Get current node information current_entity_id = record["current_id"] current_node = node_map[current_entity_id] - + # Get neighbor node information neighbor_entity_id = record["neighbor_id"] neighbor_internal_id = str(record["neighbor_internal_id"]) is_outgoing = record["is_outgoing"] - + # Determine edge direction if is_outgoing: source_id = current_node.id @@ -1972,25 +1974,25 @@ class PGGraphStorage(BaseGraphStorage): else: source_id = neighbor_internal_id target_id = current_node.id - + if not neighbor_entity_id: continue - + # Get edge and node information b_node = record["neighbor"] rel = record["r"] edge_id = str(record["edge_id"]) - + # Create neighbor node object neighbor_node = KnowledgeGraphNode( id=neighbor_internal_id, labels=[neighbor_entity_id], properties=b_node["properties"], ) - + # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id])) - + # Create edge object edge = KnowledgeGraphEdge( id=edge_id, @@ -1999,7 +2001,7 @@ class PGGraphStorage(BaseGraphStorage): target=target_id, properties=rel["properties"], ) - + if neighbor_internal_id in visited_node_ids: # Add backward edge if neighbor node is already visited if ( @@ -2015,10 +2017,10 @@ class PGGraphStorage(BaseGraphStorage): result.nodes.append(neighbor_node) visited_nodes.add(neighbor_entity_id) visited_node_ids.add(neighbor_internal_id) - + # Add node to queue with incremented depth queue.append((neighbor_node, current_depth + 1)) - + # Add forward edge if ( edge_id not in visited_edges @@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage): KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ + # 初始化 kg 变量,确保在所有情况下都有定义 + kg = KnowledgeGraph() + # Handle wildcard query - get all nodes if node_label == "*": # First check total node count to determine if graph should be truncated @@ -2063,57 +2068,91 @@ class PGGraphStorage(BaseGraphStorage): total_nodes = count_result[0]["total_nodes"] if count_result else 0 is_truncated = total_nodes > max_nodes - # Get nodes and edges - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (node:base) - OPTIONAL MATCH (node)-[r]->() - RETURN collect(distinct node) AS n, collect(distinct r) AS r - LIMIT {max_nodes} - $$) AS (n agtype, r agtype)""" + # Get max_nodes with highest degrees + query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + OPTIONAL MATCH (n)-[r]->() + RETURN id(n) as node_id, count(r) as degree + $$) AS (node_id BIGINT, degree BIGINT) + ORDER BY degree DESC + LIMIT {max_nodes}""" + node_results = await self._query(query_nodes) - results = await self._query(query) + node_ids = [str(result["node_id"]) for result in node_results] - # Process query results, deduplicate nodes and edges - nodes_dict = {} - edges_dict = {} + logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}") - for result in results: - if result.get("n") and isinstance(result["n"], list): - for node in result["n"]: - if isinstance(node, dict) and "id" in node: - node_id = str(node["id"]) - if node_id not in nodes_dict and "properties" in node: - nodes_dict[node_id] = KnowledgeGraphNode( - id=node_id, - labels=[node["properties"]["entity_id"]], - properties=node["properties"], - ) + if node_ids: + formatted_ids = ", ".join(node_ids) + # Construct batch query for subgraph within max_nodes + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + WITH [{formatted_ids}] AS node_ids + MATCH (a) + WHERE id(a) IN node_ids + OPTIONAL MATCH (a)-[r]->(b) + WHERE id(b) IN node_ids + RETURN a, r, b + $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)""" + results = await self._query(query) - if result.get("r") and isinstance(result["r"], list): - for edge in result["r"]: - if isinstance(edge, dict) and "id" in edge: - edge_id = str(edge["id"]) - if edge_id not in edges_dict: - edges_dict[edge_id] = KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(edge["start_id"]), - target=str(edge["end_id"]), - properties=edge["properties"], - ) + # Process query results, deduplicate nodes and edges + nodes_dict = {} + edges_dict = {} + for result in results: + # 处理节点 a + if result.get("a") and isinstance(result["a"], dict): + node_a = result["a"] + node_id = str(node_a["id"]) + if node_id not in nodes_dict and "properties" in node_a: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node_a["properties"]["entity_id"]], + properties=node_a["properties"], + ) - kg = KnowledgeGraph( - nodes=list(nodes_dict.values()), - edges=list(edges_dict.values()), - is_truncated=is_truncated, + # 处理节点 b + if result.get("b") and isinstance(result["b"], dict): + node_b = result["b"] + node_id = str(node_b["id"]) + if node_id not in nodes_dict and "properties" in node_b: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node_b["properties"]["entity_id"]], + properties=node_b["properties"], + ) + + # 处理边 r + if result.get("r") and isinstance(result["r"], dict): + edge = result["r"] + edge_id = str(edge["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type=edge["label"], + source=str(edge["start_id"]), + target=str(edge["end_id"]), + properties=edge["properties"], + ) + + kg = KnowledgeGraph( + nodes=list(nodes_dict.values()), + edges=list(edges_dict.values()), + is_truncated=is_truncated, + ) + else: + # For single node query, use BFS algorithm + kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) + + logger.info( + f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" ) else: - # For single node query, use BFS algorithm + # 非通配符查询,使用 BFS 算法 kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) + logger.info( + f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" + ) - logger.info( - f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" - ) return kg async def drop(self) -> dict[str, str]: