From 6f498a678c3bb469d61ed42f01f6f48b230cd015 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 13 Apr 2025 01:07:07 +0800 Subject: [PATCH] Implement batch query funtions for PGGraphStorage of PostgreSQl AGE graph storage --- lightrag/kg/postgres_impl.py | 191 +++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index c067a0d0..f567055c 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1458,6 +1458,197 @@ class PGGraphStorage(BaseGraphStorage): logger.error(f"Error during edge deletion: {str(e)}") raise + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + """ + Retrieve multiple nodes in one query using UNWIND. + + Args: + node_ids: List of node entity IDs to fetch. + + Returns: + A dictionary mapping each node_id to its node data (or None if not found). + """ + if not node_ids: + return {} + + # Format node IDs for the query + formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) + + query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + RETURN node_id, n + $$) AS (node_id text, n agtype)""" % ( + self.graph_name, + formatted_ids + ) + + results = await self._query(query) + + # Build result dictionary + nodes_dict = {} + for result in results: + if result["node_id"] and result["n"]: + node_dict = result["n"]["properties"] + # Remove the 'base' label if present in a 'labels' property + if "labels" in node_dict: + node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] + nodes_dict[result["node_id"]] = node_dict + + return nodes_dict + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + """ + Retrieve the degree for multiple nodes in a single query using UNWIND. + + Args: + node_ids: List of node labels (entity_id values) to look up. + + Returns: + A dictionary mapping each node_id to its degree (number of relationships). + If a node is not found, its degree will be set to 0. + """ + if not node_ids: + return {} + + # Format node IDs for the query + formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) + + query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n)-[r]-() + RETURN node_id, count(r) AS degree + $$) AS (node_id text, degree bigint)""" % ( + self.graph_name, + formatted_ids + ) + + results = await self._query(query) + + # Build result dictionary + degrees_dict = {} + for result in results: + if result["node_id"] is not None: + degrees_dict[result["node_id"]] = int(result["degree"]) + + # Ensure all requested node_ids are in the result dictionary + for node_id in node_ids: + if node_id not in degrees_dict: + degrees_dict[node_id] = 0 + + return degrees_dict + + async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]: + """ + Calculate the combined degree for each edge (sum of the source and target node degrees) + in batch using the already implemented node_degrees_batch. + + Args: + edges: List of (source_node_id, target_node_id) tuples + + Returns: + Dictionary mapping edge tuples to their combined degrees + """ + if not edges: + return {} + + # Use node_degrees_batch to get all node degrees efficiently + all_nodes = set() + for src, tgt in edges: + all_nodes.add(src) + all_nodes.add(tgt) + + node_degrees = await self.node_degrees_batch(list(all_nodes)) + + # Calculate edge degrees + edge_degrees_dict = {} + for src, tgt in edges: + src_degree = node_degrees.get(src, 0) + tgt_degree = node_degrees.get(tgt, 0) + edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree + + return edge_degrees_dict + + async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: + """ + Retrieve edge properties for multiple (src, tgt) pairs in one query. + + Args: + pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] + + Returns: + A dictionary mapping (src, tgt) tuples to their edge properties. + """ + if not pairs: + return {} + + # 从字典列表中提取源节点和目标节点ID + src_nodes = [] + tgt_nodes = [] + for pair in pairs: + src_nodes.append(pair["src"].replace('"', '')) + tgt_nodes.append(pair["tgt"].replace('"', '')) + + # 构建查询,使用数组索引来匹配源节点和目标节点 + src_array = ", ".join([f'"{src}"' for src in src_nodes]) + tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) + + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + WITH [{src_array}] AS sources, [{tgt_array}] AS targets + UNWIND range(0, size(sources)-1) AS i + MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}}) + RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties + $$) AS (source text, target text, edge_properties agtype)""" + + results = await self._query(query) + + # 构建结果字典 + edges_dict = {} + for result in results: + if result["source"] and result["target"] and result["edge_properties"]: + edges_dict[(result["source"], result["target"])] = result["edge_properties"] + + return edges_dict + + async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: + """ + Get all edges for multiple nodes in a single batch operation. + + Args: + node_ids: List of node IDs to get edges for + + Returns: + Dictionary mapping node IDs to lists of (source, target) edge tuples + """ + if not node_ids: + return {} + + # Format node IDs for the query + formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) + + query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n)-[]-(connected:base) + RETURN node_id, connected.entity_id AS connected_id + $$) AS (node_id text, connected_id text)""" % ( + self.graph_name, + formatted_ids + ) + + results = await self._query(query) + + # Build result dictionary + nodes_edges_dict = {node_id: [] for node_id in node_ids} + for result in results: + if result["node_id"] and result["connected_id"]: + nodes_edges_dict[result["node_id"]].append( + (result["node_id"], result["connected_id"]) + ) + + return nodes_edges_dict + async def get_all_labels(self) -> list[str]: """ Get all labels (node IDs) in the graph.