Improve graph query speed by using batch operations
This commit is contained in:
@@ -1881,77 +1881,127 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
result.is_truncated = False
|
result.is_truncated = False
|
||||||
|
|
||||||
|
# BFS search main loop
|
||||||
while queue:
|
while queue:
|
||||||
# Dequeue the next node to process from the front of the queue
|
# Get all nodes at the current depth
|
||||||
current_node, current_depth = queue.popleft()
|
current_level_nodes = []
|
||||||
|
current_depth = None
|
||||||
# Check one more depth for backward edges
|
|
||||||
|
# Determine current depth
|
||||||
|
if queue:
|
||||||
|
current_depth = queue[0][1]
|
||||||
|
|
||||||
|
# Extract all nodes at current depth from the queue
|
||||||
|
while queue and queue[0][1] == current_depth:
|
||||||
|
node, depth = queue.popleft()
|
||||||
|
if depth > max_depth:
|
||||||
|
continue
|
||||||
|
current_level_nodes.append(node)
|
||||||
|
|
||||||
|
if not current_level_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check depth limit
|
||||||
if current_depth > max_depth:
|
if current_depth > max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency
|
# Prepare node IDs list
|
||||||
current_entity_id = current_node.labels[0]
|
node_ids = [node.labels[0] for node in current_level_nodes]
|
||||||
outgoing_query = """SELECT * FROM cypher('%s', $$
|
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
|
||||||
MATCH (a:base {entity_id: "%s"})-[r]->(b)
|
|
||||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
# Construct batch query for outgoing edges
|
||||||
RETURN r, b, edge_id, target_id
|
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
|
UNWIND [{formatted_ids}] AS node_id
|
||||||
self.graph_name,
|
MATCH (n:base {{entity_id: node_id}})
|
||||||
current_entity_id,
|
OPTIONAL MATCH (n)-[r]->(neighbor:base)
|
||||||
)
|
RETURN node_id AS current_id,
|
||||||
incoming_query = """SELECT * FROM cypher('%s', $$
|
id(n) AS current_internal_id,
|
||||||
MATCH (a:base {entity_id: "%s"})<-[r]-(b)
|
id(neighbor) AS neighbor_internal_id,
|
||||||
WITH r, b, id(r) as edge_id, id(b) as target_id
|
neighbor.entity_id AS neighbor_id,
|
||||||
RETURN r, b, edge_id, target_id
|
id(r) AS edge_id,
|
||||||
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
|
r,
|
||||||
self.graph_name,
|
neighbor,
|
||||||
current_entity_id,
|
true AS is_outgoing
|
||||||
)
|
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
||||||
|
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
||||||
outgoing_neighbors = await self._query(outgoing_query)
|
|
||||||
incoming_neighbors = await self._query(incoming_query)
|
# Construct batch query for incoming edges
|
||||||
neighbors = outgoing_neighbors + incoming_neighbors
|
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
|
UNWIND [{formatted_ids}] AS node_id
|
||||||
# logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})")
|
MATCH (n:base {{entity_id: node_id}})
|
||||||
|
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
|
||||||
|
RETURN node_id AS current_id,
|
||||||
|
id(n) AS current_internal_id,
|
||||||
|
id(neighbor) AS neighbor_internal_id,
|
||||||
|
neighbor.entity_id AS neighbor_id,
|
||||||
|
id(r) AS edge_id,
|
||||||
|
r,
|
||||||
|
neighbor,
|
||||||
|
false AS is_outgoing
|
||||||
|
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
||||||
|
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
||||||
|
|
||||||
|
# Execute queries
|
||||||
|
outgoing_results = await self._query(outgoing_query)
|
||||||
|
incoming_results = await self._query(incoming_query)
|
||||||
|
|
||||||
|
# Combine results
|
||||||
|
neighbors = outgoing_results + incoming_results
|
||||||
|
|
||||||
|
# Create mapping from node ID to node object
|
||||||
|
node_map = {node.labels[0]: node for node in current_level_nodes}
|
||||||
|
|
||||||
|
# Process all results in a single loop
|
||||||
for record in neighbors:
|
for record in neighbors:
|
||||||
if not record.get("b") or not record.get("r"):
|
if not record.get("neighbor") or not record.get("r"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
b_node = record["b"]
|
# Get current node information
|
||||||
|
current_entity_id = record["current_id"]
|
||||||
|
current_node = node_map[current_entity_id]
|
||||||
|
|
||||||
|
# Get neighbor node information
|
||||||
|
neighbor_entity_id = record["neighbor_id"]
|
||||||
|
neighbor_internal_id = str(record["neighbor_internal_id"])
|
||||||
|
is_outgoing = record["is_outgoing"]
|
||||||
|
|
||||||
|
# Determine edge direction
|
||||||
|
if is_outgoing:
|
||||||
|
source_id = current_node.id
|
||||||
|
target_id = neighbor_internal_id
|
||||||
|
else:
|
||||||
|
source_id = neighbor_internal_id
|
||||||
|
target_id = current_node.id
|
||||||
|
|
||||||
|
if not neighbor_entity_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get edge and node information
|
||||||
|
b_node = record["neighbor"]
|
||||||
rel = record["r"]
|
rel = record["r"]
|
||||||
edge_id = str(record["edge_id"])
|
edge_id = str(record["edge_id"])
|
||||||
|
|
||||||
if (
|
# Create neighbor node object
|
||||||
"properties" not in b_node
|
neighbor_node = KnowledgeGraphNode(
|
||||||
or "entity_id" not in b_node["properties"]
|
id=neighbor_internal_id,
|
||||||
):
|
labels=[neighbor_entity_id],
|
||||||
continue
|
|
||||||
|
|
||||||
target_entity_id = b_node["properties"]["entity_id"]
|
|
||||||
target_internal_id = str(b_node["id"])
|
|
||||||
|
|
||||||
# Create KnowledgeGraphNode for target
|
|
||||||
target_node = KnowledgeGraphNode(
|
|
||||||
id=target_internal_id,
|
|
||||||
labels=[target_entity_id],
|
|
||||||
properties=b_node["properties"],
|
properties=b_node["properties"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
||||||
sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))
|
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
|
||||||
|
|
||||||
# Create edge object
|
# Create edge object
|
||||||
edge = KnowledgeGraphEdge(
|
edge = KnowledgeGraphEdge(
|
||||||
id=edge_id,
|
id=edge_id,
|
||||||
type=rel["label"],
|
type=rel["label"],
|
||||||
source=current_node.id,
|
source=source_id,
|
||||||
target=target_internal_id,
|
target=target_id,
|
||||||
properties=rel["properties"],
|
properties=rel["properties"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if target_internal_id in visited_node_ids:
|
if neighbor_internal_id in visited_node_ids:
|
||||||
# Add backward edge if target node is visited
|
# Add backward edge if neighbor node is already visited
|
||||||
if (
|
if (
|
||||||
edge_id not in visited_edges
|
edge_id not in visited_edges
|
||||||
and sorted_pair not in visited_edge_pairs
|
and sorted_pair not in visited_edge_pairs
|
||||||
@@ -1959,17 +2009,16 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
result.edges.append(edge)
|
result.edges.append(edge)
|
||||||
visited_edges.add(edge_id)
|
visited_edges.add(edge_id)
|
||||||
visited_edge_pairs.add(sorted_pair)
|
visited_edge_pairs.add(sorted_pair)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if len(visited_node_ids) < max_nodes and current_depth < max_depth:
|
if len(visited_node_ids) < max_nodes and current_depth < max_depth:
|
||||||
# If target node not yet visited, add to result and queue
|
# Add new node to result and queue
|
||||||
result.nodes.append(target_node)
|
result.nodes.append(neighbor_node)
|
||||||
visited_nodes.add(target_entity_id)
|
visited_nodes.add(neighbor_entity_id)
|
||||||
visited_node_ids.add(target_internal_id)
|
visited_node_ids.add(neighbor_internal_id)
|
||||||
|
|
||||||
# Add node to queue with incremented depth
|
# Add node to queue with incremented depth
|
||||||
queue.append((target_node, current_depth + 1))
|
queue.append((neighbor_node, current_depth + 1))
|
||||||
|
|
||||||
# Add forward edge
|
# Add forward edge
|
||||||
if (
|
if (
|
||||||
edge_id not in visited_edges
|
edge_id not in visited_edges
|
||||||
@@ -1978,7 +2027,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
result.edges.append(edge)
|
result.edges.append(edge)
|
||||||
visited_edges.add(edge_id)
|
visited_edges.add(edge_id)
|
||||||
visited_edge_pairs.add(sorted_pair)
|
visited_edge_pairs.add(sorted_pair)
|
||||||
# logger.info(f"Forward edge from {current_entity_id} to {target_entity_id}")
|
|
||||||
else:
|
else:
|
||||||
if current_depth < max_depth:
|
if current_depth < max_depth:
|
||||||
result.is_truncated = True
|
result.is_truncated = True
|
||||||
|
Reference in New Issue
Block a user