Implement batch query funtions for PGGraphStorage of PostgreSQl AGE graph storage

This commit is contained in:
yangdx
2025-04-13 01:07:07 +08:00
parent 99f24cd51e
commit 6f498a678c

View File

@@ -1458,6 +1458,197 @@ class PGGraphStorage(BaseGraphStorage):
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
"""
Retrieve multiple nodes in one query using UNWIND.
Args:
node_ids: List of node entity IDs to fetch.
Returns:
A dictionary mapping each node_id to its node data (or None if not found).
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
RETURN node_id, n
$$) AS (node_id text, n agtype)""" % (
self.graph_name,
formatted_ids
)
results = await self._query(query)
# Build result dictionary
nodes_dict = {}
for result in results:
if result["node_id"] and result["n"]:
node_dict = result["n"]["properties"]
# Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
nodes_dict[result["node_id"]] = node_dict
return nodes_dict
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
"""
Retrieve the degree for multiple nodes in a single query using UNWIND.
Args:
node_ids: List of node labels (entity_id values) to look up.
Returns:
A dictionary mapping each node_id to its degree (number of relationships).
If a node is not found, its degree will be set to 0.
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]-()
RETURN node_id, count(r) AS degree
$$) AS (node_id text, degree bigint)""" % (
self.graph_name,
formatted_ids
)
results = await self._query(query)
# Build result dictionary
degrees_dict = {}
for result in results:
if result["node_id"] is not None:
degrees_dict[result["node_id"]] = int(result["degree"])
# Ensure all requested node_ids are in the result dictionary
for node_id in node_ids:
if node_id not in degrees_dict:
degrees_dict[node_id] = 0
return degrees_dict
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
"""
Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch.
Args:
edges: List of (source_node_id, target_node_id) tuples
Returns:
Dictionary mapping edge tuples to their combined degrees
"""
if not edges:
return {}
# Use node_degrees_batch to get all node degrees efficiently
all_nodes = set()
for src, tgt in edges:
all_nodes.add(src)
all_nodes.add(tgt)
node_degrees = await self.node_degrees_batch(list(all_nodes))
# Calculate edge degrees
edge_degrees_dict = {}
for src, tgt in edges:
src_degree = node_degrees.get(src, 0)
tgt_degree = node_degrees.get(tgt, 0)
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
return edge_degrees_dict
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
"""
Retrieve edge properties for multiple (src, tgt) pairs in one query.
Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
Returns:
A dictionary mapping (src, tgt) tuples to their edge properties.
"""
if not pairs:
return {}
# 从字典列表中提取源节点和目标节点ID
src_nodes = []
tgt_nodes = []
for pair in pairs:
src_nodes.append(pair["src"].replace('"', ''))
tgt_nodes.append(pair["tgt"].replace('"', ''))
# 构建查询,使用数组索引来匹配源节点和目标节点
src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
results = await self._query(query)
# 构建结果字典
edges_dict = {}
for result in results:
if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
return edges_dict
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
"""
Get all edges for multiple nodes in a single batch operation.
Args:
node_ids: List of node IDs to get edges for
Returns:
Dictionary mapping node IDs to lists of (source, target) edge tuples
"""
if not node_ids:
return {}
# Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[]-(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids
)
results = await self._query(query)
# Build result dictionary
nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in results:
if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"])
)
return nodes_edges_dict
async def get_all_labels(self) -> list[str]:
"""
Get all labels (node IDs) in the graph.