Improve graph query speed by batch operation

This commit is contained in:
yangdx
2025-04-25 16:55:47 +08:00
parent 5ca430adcd
commit e3c87dd6bd

View File

@@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage):
result.is_truncated = False result.is_truncated = False
# BFS search main loop
while queue: while queue:
# Dequeue the next node to process from the front of the queue # Get all nodes at the current depth
current_node, current_depth = queue.popleft() current_level_nodes = []
current_depth = None
# Check one more depth for backward edges
# Determine current depth
if queue:
current_depth = queue[0][1]
# Extract all nodes at current depth from the queue
while queue and queue[0][1] == current_depth:
node, depth = queue.popleft()
if depth > max_depth:
continue
current_level_nodes.append(node)
if not current_level_nodes:
continue
# Check depth limit
if current_depth > max_depth: if current_depth > max_depth:
continue continue
# Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency # Prepare node IDs list
current_entity_id = current_node.labels[0] node_ids = [node.labels[0] for node in current_level_nodes]
outgoing_query = """SELECT * FROM cypher('%s', $$ formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
MATCH (a:base {entity_id: "%s"})-[r]->(b)
WITH r, b, id(r) as edge_id, id(b) as target_id # Construct batch query for outgoing edges
RETURN r, b, edge_id, target_id outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % ( UNWIND [{formatted_ids}] AS node_id
self.graph_name, MATCH (n:base {{entity_id: node_id}})
current_entity_id, OPTIONAL MATCH (n)-[r]->(neighbor:base)
) RETURN node_id AS current_id,
incoming_query = """SELECT * FROM cypher('%s', $$ id(n) AS current_internal_id,
MATCH (a:base {entity_id: "%s"})<-[r]-(b) id(neighbor) AS neighbor_internal_id,
WITH r, b, id(r) as edge_id, id(b) as target_id neighbor.entity_id AS neighbor_id,
RETURN r, b, edge_id, target_id id(r) AS edge_id,
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % ( r,
self.graph_name, neighbor,
current_entity_id, true AS is_outgoing
) $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
outgoing_neighbors = await self._query(outgoing_query)
incoming_neighbors = await self._query(incoming_query) # Construct batch query for incoming edges
neighbors = outgoing_neighbors + incoming_neighbors incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
# logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})") MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
neighbor,
false AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Execute queries
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
# Combine results
neighbors = outgoing_results + incoming_results
# Create mapping from node ID to node object
node_map = {node.labels[0]: node for node in current_level_nodes}
# Process all results in a single loop
for record in neighbors: for record in neighbors:
if not record.get("b") or not record.get("r"): if not record.get("neighbor") or not record.get("r"):
continue continue
b_node = record["b"] # Get current node information
current_entity_id = record["current_id"]
current_node = node_map[current_entity_id]
# Get neighbor node information
neighbor_entity_id = record["neighbor_id"]
neighbor_internal_id = str(record["neighbor_internal_id"])
is_outgoing = record["is_outgoing"]
# Determine edge direction
if is_outgoing:
source_id = current_node.id
target_id = neighbor_internal_id
else:
source_id = neighbor_internal_id
target_id = current_node.id
if not neighbor_entity_id:
continue
# Get edge and node information
b_node = record["neighbor"]
rel = record["r"] rel = record["r"]
edge_id = str(record["edge_id"]) edge_id = str(record["edge_id"])
if ( # Create neighbor node object
"properties" not in b_node neighbor_node = KnowledgeGraphNode(
or "entity_id" not in b_node["properties"] id=neighbor_internal_id,
): labels=[neighbor_entity_id],
continue
target_entity_id = b_node["properties"]["entity_id"]
target_internal_id = str(b_node["id"])
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=target_internal_id,
labels=[target_entity_id],
properties=b_node["properties"], properties=b_node["properties"],
) )
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_entity_id, target_entity_id])) sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
# Create edge object # Create edge object
edge = KnowledgeGraphEdge( edge = KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type=rel["label"], type=rel["label"],
source=current_node.id, source=source_id,
target=target_internal_id, target=target_id,
properties=rel["properties"], properties=rel["properties"],
) )
if target_internal_id in visited_node_ids: if neighbor_internal_id in visited_node_ids:
# Add backward edge if target node is visited # Add backward edge if neighbor node is already visited
if ( if (
edge_id not in visited_edges edge_id not in visited_edges
and sorted_pair not in visited_edge_pairs and sorted_pair not in visited_edge_pairs
@@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage):
result.edges.append(edge) result.edges.append(edge)
visited_edges.add(edge_id) visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair) visited_edge_pairs.add(sorted_pair)
else: else:
if len(visited_node_ids) < max_nodes and current_depth < max_depth: if len(visited_node_ids) < max_nodes and current_depth < max_depth:
# If target node not yet visited, add to result and queue # Add new node to result and queue
result.nodes.append(target_node) result.nodes.append(neighbor_node)
visited_nodes.add(target_entity_id) visited_nodes.add(neighbor_entity_id)
visited_node_ids.add(target_internal_id) visited_node_ids.add(neighbor_internal_id)
# Add node to queue with incremented depth # Add node to queue with incremented depth
queue.append((target_node, current_depth + 1)) queue.append((neighbor_node, current_depth + 1))
# Add forward edge # Add forward edge
if ( if (
edge_id not in visited_edges edge_id not in visited_edges
@@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage):
result.edges.append(edge) result.edges.append(edge)
visited_edges.add(edge_id) visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair) visited_edge_pairs.add(sorted_pair)
# logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}")
else: else:
if current_depth < max_depth: if current_depth < max_depth:
result.is_truncated = True result.is_truncated = True