Improve graph query speed with batched operations

This commit is contained in:
yangdx
2025-04-25 16:55:47 +08:00
parent 5ca430adcd
commit e3c87dd6bd

View File

@@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage):
result.is_truncated = False
# BFS search main loop
while queue:
# Dequeue the next node to process from the front of the queue
current_node, current_depth = queue.popleft()
# Check one more depth for backward edges
# Get all nodes at the current depth
current_level_nodes = []
current_depth = None
# Determine current depth
if queue:
current_depth = queue[0][1]
# Extract all nodes at current depth from the queue
while queue and queue[0][1] == current_depth:
node, depth = queue.popleft()
if depth > max_depth:
continue
current_level_nodes.append(node)
if not current_level_nodes:
continue
# Check depth limit
if current_depth > max_depth:
continue
# Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency
current_entity_id = current_node.labels[0]
outgoing_query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]->(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
self.graph_name,
current_entity_id,
)
incoming_query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})<-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
self.graph_name,
current_entity_id,
)
outgoing_neighbors = await self._query(outgoing_query)
incoming_neighbors = await self._query(incoming_query)
neighbors = outgoing_neighbors + incoming_neighbors
# logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})")
# Prepare node IDs list
node_ids = [node.labels[0] for node in current_level_nodes]
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
# Construct batch query for outgoing edges
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)-[r]->(neighbor:base)
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
neighbor,
true AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Construct batch query for incoming edges
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
neighbor,
false AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Execute queries
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
# Combine results
neighbors = outgoing_results + incoming_results
# Create mapping from node ID to node object
node_map = {node.labels[0]: node for node in current_level_nodes}
# Process all results in a single loop
for record in neighbors:
if not record.get("b") or not record.get("r"):
if not record.get("neighbor") or not record.get("r"):
continue
b_node = record["b"]
# Get current node information
current_entity_id = record["current_id"]
current_node = node_map[current_entity_id]
# Get neighbor node information
neighbor_entity_id = record["neighbor_id"]
neighbor_internal_id = str(record["neighbor_internal_id"])
is_outgoing = record["is_outgoing"]
# Determine edge direction
if is_outgoing:
source_id = current_node.id
target_id = neighbor_internal_id
else:
source_id = neighbor_internal_id
target_id = current_node.id
if not neighbor_entity_id:
continue
# Get edge and node information
b_node = record["neighbor"]
rel = record["r"]
edge_id = str(record["edge_id"])
if (
"properties" not in b_node
or "entity_id" not in b_node["properties"]
):
continue
target_entity_id = b_node["properties"]["entity_id"]
target_internal_id = str(b_node["id"])
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=target_internal_id,
labels=[target_entity_id],
# Create neighbor node object
neighbor_node = KnowledgeGraphNode(
id=neighbor_internal_id,
labels=[neighbor_entity_id],
properties=b_node["properties"],
)
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
# Create edge object
edge = KnowledgeGraphEdge(
id=edge_id,
type=rel["label"],
source=current_node.id,
target=target_internal_id,
source=source_id,
target=target_id,
properties=rel["properties"],
)
if target_internal_id in visited_node_ids:
# Add backward edge if target node is visited
if neighbor_internal_id in visited_node_ids:
# Add backward edge if neighbor node is already visited
if (
edge_id not in visited_edges
and sorted_pair not in visited_edge_pairs
@@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage):
result.edges.append(edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
else:
if len(visited_node_ids) < max_nodes and current_depth < max_depth:
# If target node not yet visited, add to result and queue
result.nodes.append(target_node)
visited_nodes.add(target_entity_id)
visited_node_ids.add(target_internal_id)
# Add new node to result and queue
result.nodes.append(neighbor_node)
visited_nodes.add(neighbor_entity_id)
visited_node_ids.add(neighbor_internal_id)
# Add node to queue with incremented depth
queue.append((target_node, current_depth + 1))
queue.append((neighbor_node, current_depth + 1))
# Add forward edge
if (
edge_id not in visited_edges
@@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage):
result.edges.append(edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
# logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}")
else:
if current_depth < max_depth:
result.is_truncated = True