diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 7c52b178..4b6438ae 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1170,9 +1170,6 @@ class PGGraphStorage(BaseGraphStorage): Returns: list[dict[str, Any]]: a list of dictionaries containing the result set """ - - logger.info(f"Executing graph query: {query}") - try: if readonly: data = await self.db.query( @@ -1255,8 +1252,8 @@ class PGGraphStorage(BaseGraphStorage): label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"})-[]-(x) - RETURN count(x) AS total_edge_count + MATCH (n:base {entity_id: "%s"})-[r]-() + RETURN count(r) AS total_edge_count $$) AS (total_edge_count integer)""" % (self.graph_name, label) record = (await self._query(query))[0] if record: @@ -1523,12 +1520,14 @@ class PGGraphStorage(BaseGraphStorage): async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: """ Retrieve the degree for multiple nodes in a single query using UNWIND. + Calculates the total degree by counting distinct relationships. + Uses separate queries for outgoing and incoming edges. Args: node_ids: List of node labels (entity_id values) to look up. Returns: - A dictionary mapping each node_id to its degree (number of relationships). + A dictionary mapping each node_id to its degree (total number of relationships). If a node is not found, its degree will be set to 0. """ if not node_ids: @@ -1539,28 +1538,45 @@ class PGGraphStorage(BaseGraphStorage): ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ) - query = """SELECT * FROM cypher('%s', $$ + outgoing_query = """SELECT * FROM cypher('%s', $$ UNWIND [%s] AS node_id MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n)-[r]->() - RETURN node_id, count(r) AS degree - $$) AS (node_id text, degree bigint)""" % ( + OPTIONAL MATCH (n)-[r]->(a) + RETURN node_id, count(a) AS out_degree + $$) AS (node_id text, out_degree bigint)""" % ( self.graph_name, formatted_ids, ) - results = await self._query(query) + incoming_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n)<-[r]-(b) + RETURN node_id, count(b) AS in_degree + $$) AS (node_id text, in_degree bigint)""" % ( + self.graph_name, + formatted_ids, + ) - # Build result dictionary - degrees_dict = {} - for result in results: + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) + + out_degrees = {} + in_degrees = {} + + for result in outgoing_results: if result["node_id"] is not None: - degrees_dict[result["node_id"]] = int(result["degree"]) - - # Ensure all requested node_ids are in the result dictionary + out_degrees[result["node_id"]] = int(result["out_degree"]) + + for result in incoming_results: + if result["node_id"] is not None: + in_degrees[result["node_id"]] = int(result["in_degree"]) + + degrees_dict = {} for node_id in node_ids: - if node_id not in degrees_dict: - degrees_dict[node_id] = 0 + out_degree = out_degrees.get(node_id, 0) + in_degree = in_degrees.get(node_id, 0) + degrees_dict[node_id] = out_degree + in_degree return degrees_dict @@ -1602,6 +1618,7 @@ class PGGraphStorage(BaseGraphStorage): ) -> dict[tuple[str, str], dict]: """ Retrieve edge properties for multiple (src, tgt) pairs in one query. + Get forward and backward edges seperately and merge them before return Args: pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] @@ -1612,33 +1629,41 @@ class PGGraphStorage(BaseGraphStorage): if not pairs: return {} - # 从字典列表中提取源节点和目标节点ID src_nodes = [] tgt_nodes = [] for pair in pairs: src_nodes.append(pair["src"].replace('"', "")) tgt_nodes.append(pair["tgt"].replace('"', "")) - # 构建查询,使用数组索引来匹配源节点和目标节点 src_array = ", ".join([f'"{src}"' for src in src_nodes]) tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ WITH [{src_array}] AS sources, [{tgt_array}] AS targets UNWIND range(0, size(sources)-1) AS i MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}}) RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties $$) AS (source text, target text, edge_properties agtype)""" - results = await self._query(query) + backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + WITH [{src_array}] AS sources, [{tgt_array}] AS targets + UNWIND range(0, size(sources)-1) AS i + MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}}) + RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties + $$) AS (source text, target text, edge_properties agtype)""" + + forward_results = await self._query(forward_query) + backward_results = await self._query(backward_query) - # 构建结果字典 edges_dict = {} - for result in results: + + for result in forward_results: if result["source"] and result["target"] and result["edge_properties"]: - edges_dict[(result["source"], result["target"])] = result[ - "edge_properties" - ] + edges_dict[(result["source"], result["target"])] = result["edge_properties"] + + for result in backward_results: + if result["source"] and result["target"] and result["edge_properties"]: + edges_dict[(result["source"], result["target"])] = result["edge_properties"] return edges_dict @@ -1646,7 +1671,7 @@ class PGGraphStorage(BaseGraphStorage): self, node_ids: list[str] ) -> dict[str, list[tuple[str, str]]]: """ - Get all edges for multiple nodes in a single batch operation. + Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation. Args: node_ids: List of node IDs to get edges for @@ -1662,7 +1687,7 @@ class PGGraphStorage(BaseGraphStorage): ['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ) - query = """SELECT * FROM cypher('%s', $$ + outgoing_query = """SELECT * FROM cypher('%s', $$ UNWIND [%s] AS node_id MATCH (n:base {entity_id: node_id}) OPTIONAL MATCH (n:base)-[]->(connected:base) @@ -1672,15 +1697,32 @@ class PGGraphStorage(BaseGraphStorage): formatted_ids, ) - results = await self._query(query) + incoming_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n:base)<-[]-(connected:base) + RETURN node_id, connected.entity_id AS connected_id + $$) AS (node_id text, connected_id text)""" % ( + self.graph_name, + formatted_ids, + ) + + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) - # Build result dictionary nodes_edges_dict = {node_id: [] for node_id in node_ids} - for result in results: + + for result in outgoing_results: if result["node_id"] and result["connected_id"]: nodes_edges_dict[result["node_id"]].append( (result["node_id"], result["connected_id"]) ) + + for result in incoming_results: + if result["node_id"] and result["connected_id"]: + nodes_edges_dict[result["node_id"]].append( + (result["connected_id"], result["node_id"]) + ) return nodes_edges_dict