Improve graph query speed by batching operations
This commit is contained in:
@@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
result.is_truncated = False
|
||||
|
||||
# BFS search main loop
|
||||
while queue:
|
||||
# Dequeue the next node to process from the front of the queue
|
||||
current_node, current_depth = queue.popleft()
|
||||
|
||||
# Check one more depth for backward edges
|
||||
# Get all nodes at the current depth
|
||||
current_level_nodes = []
|
||||
current_depth = None
|
||||
|
||||
# Determine current depth
|
||||
if queue:
|
||||
current_depth = queue[0][1]
|
||||
|
||||
# Extract all nodes at current depth from the queue
|
||||
while queue and queue[0][1] == current_depth:
|
||||
node, depth = queue.popleft()
|
||||
if depth > max_depth:
|
||||
continue
|
||||
current_level_nodes.append(node)
|
||||
|
||||
if not current_level_nodes:
|
||||
continue
|
||||
|
||||
# Check depth limit
|
||||
if current_depth > max_depth:
|
||||
continue
|
||||
|
||||
# Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency
|
||||
current_entity_id = current_node.labels[0]
|
||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:base {entity_id: "%s"})-[r]->(b)
|
||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
||||
RETURN r, b, edge_id, target_id
|
||||
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
|
||||
self.graph_name,
|
||||
current_entity_id,
|
||||
)
|
||||
incoming_query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:base {entity_id: "%s"})<-[r]-(b)
|
||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
||||
RETURN r, b, edge_id, target_id
|
||||
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
|
||||
self.graph_name,
|
||||
current_entity_id,
|
||||
)
|
||||
|
||||
outgoing_neighbors = await self._query(outgoing_query)
|
||||
incoming_neighbors = await self._query(incoming_query)
|
||||
neighbors = outgoing_neighbors + incoming_neighbors
|
||||
|
||||
# logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})")
|
||||
|
||||
|
||||
# Prepare node IDs list
|
||||
node_ids = [node.labels[0] for node in current_level_nodes]
|
||||
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
|
||||
|
||||
# Construct batch query for outgoing edges
|
||||
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND [{formatted_ids}] AS node_id
|
||||
MATCH (n:base {{entity_id: node_id}})
|
||||
OPTIONAL MATCH (n)-[r]->(neighbor:base)
|
||||
RETURN node_id AS current_id,
|
||||
id(n) AS current_internal_id,
|
||||
id(neighbor) AS neighbor_internal_id,
|
||||
neighbor.entity_id AS neighbor_id,
|
||||
id(r) AS edge_id,
|
||||
r,
|
||||
neighbor,
|
||||
true AS is_outgoing
|
||||
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
||||
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
||||
|
||||
# Construct batch query for incoming edges
|
||||
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
UNWIND [{formatted_ids}] AS node_id
|
||||
MATCH (n:base {{entity_id: node_id}})
|
||||
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
|
||||
RETURN node_id AS current_id,
|
||||
id(n) AS current_internal_id,
|
||||
id(neighbor) AS neighbor_internal_id,
|
||||
neighbor.entity_id AS neighbor_id,
|
||||
id(r) AS edge_id,
|
||||
r,
|
||||
neighbor,
|
||||
false AS is_outgoing
|
||||
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
||||
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
||||
|
||||
# Execute queries
|
||||
outgoing_results = await self._query(outgoing_query)
|
||||
incoming_results = await self._query(incoming_query)
|
||||
|
||||
# Combine results
|
||||
neighbors = outgoing_results + incoming_results
|
||||
|
||||
# Create mapping from node ID to node object
|
||||
node_map = {node.labels[0]: node for node in current_level_nodes}
|
||||
|
||||
# Process all results in a single loop
|
||||
for record in neighbors:
|
||||
if not record.get("b") or not record.get("r"):
|
||||
if not record.get("neighbor") or not record.get("r"):
|
||||
continue
|
||||
|
||||
b_node = record["b"]
|
||||
|
||||
# Get current node information
|
||||
current_entity_id = record["current_id"]
|
||||
current_node = node_map[current_entity_id]
|
||||
|
||||
# Get neighbor node information
|
||||
neighbor_entity_id = record["neighbor_id"]
|
||||
neighbor_internal_id = str(record["neighbor_internal_id"])
|
||||
is_outgoing = record["is_outgoing"]
|
||||
|
||||
# Determine edge direction
|
||||
if is_outgoing:
|
||||
source_id = current_node.id
|
||||
target_id = neighbor_internal_id
|
||||
else:
|
||||
source_id = neighbor_internal_id
|
||||
target_id = current_node.id
|
||||
|
||||
if not neighbor_entity_id:
|
||||
continue
|
||||
|
||||
# Get edge and node information
|
||||
b_node = record["neighbor"]
|
||||
rel = record["r"]
|
||||
edge_id = str(record["edge_id"])
|
||||
|
||||
if (
|
||||
"properties" not in b_node
|
||||
or "entity_id" not in b_node["properties"]
|
||||
):
|
||||
continue
|
||||
|
||||
target_entity_id = b_node["properties"]["entity_id"]
|
||||
target_internal_id = str(b_node["id"])
|
||||
|
||||
# Create KnowledgeGraphNode for target
|
||||
target_node = KnowledgeGraphNode(
|
||||
id=target_internal_id,
|
||||
labels=[target_entity_id],
|
||||
|
||||
# Create neighbor node object
|
||||
neighbor_node = KnowledgeGraphNode(
|
||||
id=neighbor_internal_id,
|
||||
labels=[neighbor_entity_id],
|
||||
properties=b_node["properties"],
|
||||
)
|
||||
|
||||
|
||||
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
||||
sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))
|
||||
|
||||
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
|
||||
|
||||
# Create edge object
|
||||
edge = KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type=rel["label"],
|
||||
source=current_node.id,
|
||||
target=target_internal_id,
|
||||
source=source_id,
|
||||
target=target_id,
|
||||
properties=rel["properties"],
|
||||
)
|
||||
|
||||
if target_internal_id in visited_node_ids:
|
||||
# Add backward edge if target node is visited
|
||||
|
||||
if neighbor_internal_id in visited_node_ids:
|
||||
# Add backward edge if neighbor node is already visited
|
||||
if (
|
||||
edge_id not in visited_edges
|
||||
and sorted_pair not in visited_edge_pairs
|
||||
@@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
result.edges.append(edge)
|
||||
visited_edges.add(edge_id)
|
||||
visited_edge_pairs.add(sorted_pair)
|
||||
|
||||
else:
|
||||
if len(visited_node_ids) < max_nodes and current_depth < max_depth:
|
||||
# If target node not yet visited, add to result and queue
|
||||
result.nodes.append(target_node)
|
||||
visited_nodes.add(target_entity_id)
|
||||
visited_node_ids.add(target_internal_id)
|
||||
|
||||
# Add new node to result and queue
|
||||
result.nodes.append(neighbor_node)
|
||||
visited_nodes.add(neighbor_entity_id)
|
||||
visited_node_ids.add(neighbor_internal_id)
|
||||
|
||||
# Add node to queue with incremented depth
|
||||
queue.append((target_node, current_depth + 1))
|
||||
|
||||
queue.append((neighbor_node, current_depth + 1))
|
||||
|
||||
# Add forward edge
|
||||
if (
|
||||
edge_id not in visited_edges
|
||||
@@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
result.edges.append(edge)
|
||||
visited_edges.add(edge_id)
|
||||
visited_edge_pairs.add(sorted_pair)
|
||||
# logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}")
|
||||
else:
|
||||
if current_depth < max_depth:
|
||||
result.is_truncated = True
|
||||
|
Reference in New Issue
Block a user