Optimize PostgreSQL AGE graph storage performance by eperate forward and backward edge query
This commit is contained in:
@@ -1170,9 +1170,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
Returns:
|
Returns:
|
||||||
list[dict[str, Any]]: a list of dictionaries containing the result set
|
list[dict[str, Any]]: a list of dictionaries containing the result set
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info(f"Executing graph query: {query}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if readonly:
|
if readonly:
|
||||||
data = await self.db.query(
|
data = await self.db.query(
|
||||||
@@ -1255,8 +1252,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
label = node_id.strip('"')
|
label = node_id.strip('"')
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
query = """SELECT * FROM cypher('%s', $$
|
||||||
MATCH (n:base {entity_id: "%s"})-[]-(x)
|
MATCH (n:base {entity_id: "%s"})-[r]-()
|
||||||
RETURN count(x) AS total_edge_count
|
RETURN count(r) AS total_edge_count
|
||||||
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
||||||
record = (await self._query(query))[0]
|
record = (await self._query(query))[0]
|
||||||
if record:
|
if record:
|
||||||
@@ -1523,12 +1520,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||||
|
Calculates the total degree by counting distinct relationships.
|
||||||
|
Uses separate queries for outgoing and incoming edges.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: List of node labels (entity_id values) to look up.
|
node_ids: List of node labels (entity_id values) to look up.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping each node_id to its degree (number of relationships).
|
A dictionary mapping each node_id to its degree (total number of relationships).
|
||||||
If a node is not found, its degree will be set to 0.
|
If a node is not found, its degree will be set to 0.
|
||||||
"""
|
"""
|
||||||
if not node_ids:
|
if not node_ids:
|
||||||
@@ -1539,28 +1538,45 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||||
)
|
)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||||
UNWIND [%s] AS node_id
|
UNWIND [%s] AS node_id
|
||||||
MATCH (n:base {entity_id: node_id})
|
MATCH (n:base {entity_id: node_id})
|
||||||
OPTIONAL MATCH (n)-[r]->()
|
OPTIONAL MATCH (n)-[r]->(a)
|
||||||
RETURN node_id, count(r) AS degree
|
RETURN node_id, count(a) AS out_degree
|
||||||
$$) AS (node_id text, degree bigint)""" % (
|
$$) AS (node_id text, out_degree bigint)""" % (
|
||||||
self.graph_name,
|
self.graph_name,
|
||||||
formatted_ids,
|
formatted_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await self._query(query)
|
incoming_query = """SELECT * FROM cypher('%s', $$
|
||||||
|
UNWIND [%s] AS node_id
|
||||||
|
MATCH (n:base {entity_id: node_id})
|
||||||
|
OPTIONAL MATCH (n)<-[r]-(b)
|
||||||
|
RETURN node_id, count(b) AS in_degree
|
||||||
|
$$) AS (node_id text, in_degree bigint)""" % (
|
||||||
|
self.graph_name,
|
||||||
|
formatted_ids,
|
||||||
|
)
|
||||||
|
|
||||||
# Build result dictionary
|
outgoing_results = await self._query(outgoing_query)
|
||||||
degrees_dict = {}
|
incoming_results = await self._query(incoming_query)
|
||||||
for result in results:
|
|
||||||
|
out_degrees = {}
|
||||||
|
in_degrees = {}
|
||||||
|
|
||||||
|
for result in outgoing_results:
|
||||||
if result["node_id"] is not None:
|
if result["node_id"] is not None:
|
||||||
degrees_dict[result["node_id"]] = int(result["degree"])
|
out_degrees[result["node_id"]] = int(result["out_degree"])
|
||||||
|
|
||||||
# Ensure all requested node_ids are in the result dictionary
|
for result in incoming_results:
|
||||||
|
if result["node_id"] is not None:
|
||||||
|
in_degrees[result["node_id"]] = int(result["in_degree"])
|
||||||
|
|
||||||
|
degrees_dict = {}
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id not in degrees_dict:
|
out_degree = out_degrees.get(node_id, 0)
|
||||||
degrees_dict[node_id] = 0
|
in_degree = in_degrees.get(node_id, 0)
|
||||||
|
degrees_dict[node_id] = out_degree + in_degree
|
||||||
|
|
||||||
return degrees_dict
|
return degrees_dict
|
||||||
|
|
||||||
@@ -1602,6 +1618,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
) -> dict[tuple[str, str], dict]:
|
) -> dict[tuple[str, str], dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||||
|
Get forward and backward edges seperately and merge them before return
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||||
@@ -1612,33 +1629,41 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if not pairs:
|
if not pairs:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# 从字典列表中提取源节点和目标节点ID
|
|
||||||
src_nodes = []
|
src_nodes = []
|
||||||
tgt_nodes = []
|
tgt_nodes = []
|
||||||
for pair in pairs:
|
for pair in pairs:
|
||||||
src_nodes.append(pair["src"].replace('"', ""))
|
src_nodes.append(pair["src"].replace('"', ""))
|
||||||
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
||||||
|
|
||||||
# 构建查询,使用数组索引来匹配源节点和目标节点
|
|
||||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
||||||
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
||||||
|
|
||||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
||||||
UNWIND range(0, size(sources)-1) AS i
|
UNWIND range(0, size(sources)-1) AS i
|
||||||
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
|
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
|
||||||
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
||||||
$$) AS (source text, target text, edge_properties agtype)"""
|
$$) AS (source text, target text, edge_properties agtype)"""
|
||||||
|
|
||||||
results = await self._query(query)
|
backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
|
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
||||||
|
UNWIND range(0, size(sources)-1) AS i
|
||||||
|
MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
||||||
|
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
||||||
|
$$) AS (source text, target text, edge_properties agtype)"""
|
||||||
|
|
||||||
|
forward_results = await self._query(forward_query)
|
||||||
|
backward_results = await self._query(backward_query)
|
||||||
|
|
||||||
# 构建结果字典
|
|
||||||
edges_dict = {}
|
edges_dict = {}
|
||||||
for result in results:
|
|
||||||
|
for result in forward_results:
|
||||||
if result["source"] and result["target"] and result["edge_properties"]:
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
edges_dict[(result["source"], result["target"])] = result[
|
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
||||||
"edge_properties"
|
|
||||||
]
|
for result in backward_results:
|
||||||
|
if result["source"] and result["target"] and result["edge_properties"]:
|
||||||
|
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
||||||
|
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
@@ -1646,7 +1671,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
self, node_ids: list[str]
|
self, node_ids: list[str]
|
||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
"""
|
"""
|
||||||
Get all edges for multiple nodes in a single batch operation.
|
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: List of node IDs to get edges for
|
node_ids: List of node IDs to get edges for
|
||||||
@@ -1662,7 +1687,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||||
)
|
)
|
||||||
|
|
||||||
query = """SELECT * FROM cypher('%s', $$
|
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||||
UNWIND [%s] AS node_id
|
UNWIND [%s] AS node_id
|
||||||
MATCH (n:base {entity_id: node_id})
|
MATCH (n:base {entity_id: node_id})
|
||||||
OPTIONAL MATCH (n:base)-[]->(connected:base)
|
OPTIONAL MATCH (n:base)-[]->(connected:base)
|
||||||
@@ -1672,15 +1697,32 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
formatted_ids,
|
formatted_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await self._query(query)
|
incoming_query = """SELECT * FROM cypher('%s', $$
|
||||||
|
UNWIND [%s] AS node_id
|
||||||
|
MATCH (n:base {entity_id: node_id})
|
||||||
|
OPTIONAL MATCH (n:base)<-[]-(connected:base)
|
||||||
|
RETURN node_id, connected.entity_id AS connected_id
|
||||||
|
$$) AS (node_id text, connected_id text)""" % (
|
||||||
|
self.graph_name,
|
||||||
|
formatted_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
outgoing_results = await self._query(outgoing_query)
|
||||||
|
incoming_results = await self._query(incoming_query)
|
||||||
|
|
||||||
# Build result dictionary
|
|
||||||
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
||||||
for result in results:
|
|
||||||
|
for result in outgoing_results:
|
||||||
if result["node_id"] and result["connected_id"]:
|
if result["node_id"] and result["connected_id"]:
|
||||||
nodes_edges_dict[result["node_id"]].append(
|
nodes_edges_dict[result["node_id"]].append(
|
||||||
(result["node_id"], result["connected_id"])
|
(result["node_id"], result["connected_id"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for result in incoming_results:
|
||||||
|
if result["node_id"] and result["connected_id"]:
|
||||||
|
nodes_edges_dict[result["node_id"]].append(
|
||||||
|
(result["connected_id"], result["node_id"])
|
||||||
|
)
|
||||||
|
|
||||||
return nodes_edges_dict
|
return nodes_edges_dict
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user