Optimize PostgreSQL AGE graph storage performance by separating forward and backward edge queries

This commit is contained in:
yangdx
2025-04-16 14:01:21 +08:00
parent c1ea47026d
commit 40c472c4c5

View File

@@ -1170,9 +1170,6 @@ class PGGraphStorage(BaseGraphStorage):
Returns: Returns:
list[dict[str, Any]]: a list of dictionaries containing the result set list[dict[str, Any]]: a list of dictionaries containing the result set
""" """
logger.info(f"Executing graph query: {query}")
try: try:
if readonly: if readonly:
data = await self.db.query( data = await self.db.query(
@@ -1255,8 +1252,8 @@ class PGGraphStorage(BaseGraphStorage):
label = node_id.strip('"') label = node_id.strip('"')
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})-[]-(x) MATCH (n:base {entity_id: "%s"})-[r]-()
RETURN count(x) AS total_edge_count RETURN count(r) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label) $$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0] record = (await self._query(query))[0]
if record: if record:
@@ -1523,12 +1520,14 @@ class PGGraphStorage(BaseGraphStorage):
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
Calculates the total degree by counting distinct relationships.
Uses separate queries for outgoing and incoming edges.
Args: Args:
node_ids: List of node labels (entity_id values) to look up. node_ids: List of node labels (entity_id values) to look up.
Returns: Returns:
A dictionary mapping each node_id to its degree (number of relationships). A dictionary mapping each node_id to its degree (total number of relationships).
If a node is not found, its degree will be set to 0. If a node is not found, its degree will be set to 0.
""" """
if not node_ids: if not node_ids:
@@ -1539,28 +1538,45 @@ class PGGraphStorage(BaseGraphStorage):
['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
) )
query = """SELECT * FROM cypher('%s', $$ outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id}) MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)-[r]->() OPTIONAL MATCH (n)-[r]->(a)
RETURN node_id, count(r) AS degree RETURN node_id, count(a) AS out_degree
$$) AS (node_id text, degree bigint)""" % ( $$) AS (node_id text, out_degree bigint)""" % (
self.graph_name, self.graph_name,
formatted_ids, formatted_ids,
) )
results = await self._query(query) incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n)<-[r]-(b)
RETURN node_id, count(b) AS in_degree
$$) AS (node_id text, in_degree bigint)""" % (
self.graph_name,
formatted_ids,
)
# Build result dictionary outgoing_results = await self._query(outgoing_query)
degrees_dict = {} incoming_results = await self._query(incoming_query)
for result in results:
out_degrees = {}
in_degrees = {}
for result in outgoing_results:
if result["node_id"] is not None: if result["node_id"] is not None:
degrees_dict[result["node_id"]] = int(result["degree"]) out_degrees[result["node_id"]] = int(result["out_degree"])
# Ensure all requested node_ids are in the result dictionary for result in incoming_results:
if result["node_id"] is not None:
in_degrees[result["node_id"]] = int(result["in_degree"])
degrees_dict = {}
for node_id in node_ids: for node_id in node_ids:
if node_id not in degrees_dict: out_degree = out_degrees.get(node_id, 0)
degrees_dict[node_id] = 0 in_degree = in_degrees.get(node_id, 0)
degrees_dict[node_id] = out_degree + in_degree
return degrees_dict return degrees_dict
@@ -1602,6 +1618,7 @@ class PGGraphStorage(BaseGraphStorage):
) -> dict[tuple[str, str], dict]: ) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
Get forward and backward edges separately and merge them before returning.
Args: Args:
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
@@ -1612,33 +1629,41 @@ class PGGraphStorage(BaseGraphStorage):
if not pairs: if not pairs:
return {} return {}
# 从字典列表中提取源节点和目标节点ID
src_nodes = [] src_nodes = []
tgt_nodes = [] tgt_nodes = []
for pair in pairs: for pair in pairs:
src_nodes.append(pair["src"].replace('"', "")) src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', "")) tgt_nodes.append(pair["tgt"].replace('"', ""))
# 构建查询,使用数组索引来匹配源节点和目标节点
src_array = ", ".join([f'"{src}"' for src in src_nodes]) src_array = ", ".join([f'"{src}"' for src in src_nodes])
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}}) MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)""" $$) AS (source text, target text, edge_properties agtype)"""
results = await self._query(query) backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
UNWIND range(0, size(sources)-1) AS i
MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
$$) AS (source text, target text, edge_properties agtype)"""
forward_results = await self._query(forward_query)
backward_results = await self._query(backward_query)
# 构建结果字典
edges_dict = {} edges_dict = {}
for result in results:
for result in forward_results:
if result["source"] and result["target"] and result["edge_properties"]: if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result[ edges_dict[(result["source"], result["target"])] = result["edge_properties"]
"edge_properties"
] for result in backward_results:
if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
return edges_dict return edges_dict
@@ -1646,7 +1671,7 @@ class PGGraphStorage(BaseGraphStorage):
self, node_ids: list[str] self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]:
""" """
Get all edges for multiple nodes in a single batch operation. Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
Args: Args:
node_ids: List of node IDs to get edges for node_ids: List of node IDs to get edges for
@@ -1662,7 +1687,7 @@ class PGGraphStorage(BaseGraphStorage):
['"' + node_id.replace('"', "") + '"' for node_id in node_ids] ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
) )
query = """SELECT * FROM cypher('%s', $$ outgoing_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id}) MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)-[]->(connected:base) OPTIONAL MATCH (n:base)-[]->(connected:base)
@@ -1672,15 +1697,32 @@ class PGGraphStorage(BaseGraphStorage):
formatted_ids, formatted_ids,
) )
results = await self._query(query) incoming_query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id})
OPTIONAL MATCH (n:base)<-[]-(connected:base)
RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % (
self.graph_name,
formatted_ids,
)
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
# Build result dictionary
nodes_edges_dict = {node_id: [] for node_id in node_ids} nodes_edges_dict = {node_id: [] for node_id in node_ids}
for result in results:
for result in outgoing_results:
if result["node_id"] and result["connected_id"]: if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append( nodes_edges_dict[result["node_id"]].append(
(result["node_id"], result["connected_id"]) (result["node_id"], result["connected_id"])
) )
for result in incoming_results:
if result["node_id"] and result["connected_id"]:
nodes_edges_dict[result["node_id"]].append(
(result["connected_id"], result["node_id"])
)
return nodes_edges_dict return nodes_edges_dict