Improve PostgreSQL AGE graph query by writing a custom BFS implementation

yangdx
2025-04-24 12:27:12 +08:00
parent 8d219ffa32
commit 11681fdd6b


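In the diff below, get_knowledge_graph() keeps the single-Cypher path for wildcard ("*") queries and routes single-node queries through the new _bfs_subgraph() helper. A minimal usage sketch of the new behavior follows; it assumes an already-initialized PGGraphStorage instance named storage (its construction is outside this commit), and the entity id "entity-1" is a hypothetical placeholder.

# Usage sketch, not part of this commit. `storage` is assumed to be an
# initialized PGGraphStorage; "entity-1" is a hypothetical entity id.
async def demo(storage):
    # Wildcard query: single Cypher query plus a node-count check that
    # sets is_truncated when the graph exceeds max_nodes.
    kg_all = await storage.get_knowledge_graph("*", max_depth=3, max_nodes=1000)

    # Single-node query: handled by the new _bfs_subgraph() helper, which
    # guarantees breadth-first ordering and honors max_depth/max_nodes.
    kg_one = await storage.get_knowledge_graph("entity-1", max_depth=2, max_nodes=200)
    print(len(kg_one.nodes), len(kg_one.edges), kg_one.is_truncated)

# In a script, asyncio.run(demo(storage)) would drive the coroutine.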
@@ -1824,6 +1824,159 @@ class PGGraphStorage(BaseGraphStorage):
labels.append(result["label"])
return labels
async def _bfs_subgraph(
self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
"""
Implements a true breadth-first search algorithm for subgraph retrieval.
This method is used as a fallback when the standard Cypher query is too slow
or when we need to guarantee BFS ordering.
Args:
node_label: Label of the starting node
max_depth: Maximum depth of the subgraph
max_nodes: Maximum number of nodes to return
Returns:
KnowledgeGraph object containing nodes and edges
"""
from collections import deque
result = KnowledgeGraph()
visited_nodes = set()
visited_node_ids = set()
visited_edges = set()
visited_edge_pairs = set()
# Get starting node data
label = self._normalize_node_id(node_label)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
RETURN id(n) as node_id, n
$$) AS (node_id bigint, n agtype)""" % (self.graph_name, label)
node_result = await self._query(query)
if not node_result or not node_result[0].get("n"):
return result
# Create initial KnowledgeGraphNode
start_node_data = node_result[0]["n"]
entity_id = start_node_data["properties"]["entity_id"]
internal_id = str(start_node_data["id"])
start_node = KnowledgeGraphNode(
id=internal_id,
labels=[entity_id],
properties=start_node_data["properties"],
)
# Initialize BFS queue, each element is a tuple of (node, depth)
queue = deque([(start_node, 0)])
visited_nodes.add(entity_id)
visited_node_ids.add(internal_id)
result.nodes.append(start_node)
while queue and len(visited_node_ids) < max_nodes:
# Dequeue the next node to process from the front of the queue
current_node, current_depth = queue.popleft()
if current_depth >= max_depth:
continue
# Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency
current_entity_id = current_node.labels[0]
outgoing_query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]->(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
self.graph_name,
current_entity_id,
)
incoming_query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})<-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
$$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
self.graph_name,
current_entity_id,
)
outgoing_neighbors = await self._query(outgoing_query)
incoming_neighbors = await self._query(incoming_query)
neighbors = outgoing_neighbors + incoming_neighbors
# logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})")
for record in neighbors:
if not record.get("b") or not record.get("r"):
continue
b_node = record["b"]
rel = record["r"]
edge_id = str(record["edge_id"])
if (
"properties" not in b_node
or "entity_id" not in b_node["properties"]
):
continue
target_entity_id = b_node["properties"]["entity_id"]
target_internal_id = str(b_node["id"])
# Create KnowledgeGraphNode for target
target_node = KnowledgeGraphNode(
id=target_internal_id,
labels=[target_entity_id],
properties=b_node["properties"],
)
# Create edge object
edge = KnowledgeGraphEdge(
id=edge_id,
type=rel["label"],
source=current_node.id,
target=target_internal_id,
properties=rel["properties"],
)
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))
# Add edge (if not already added)
if (
edge_id not in visited_edges
and sorted_pair not in visited_edge_pairs
):
result.edges.append(edge)
visited_edges.add(edge_id)
visited_edge_pairs.add(sorted_pair)
# If target node not yet visited, add to result and queue
if target_internal_id not in visited_node_ids:
result.nodes.append(target_node)
visited_nodes.add(target_entity_id)
visited_node_ids.add(target_internal_id)
# Add node to queue with incremented depth
queue.append((target_node, current_depth + 1))
# If node limit reached, set truncated flag and exit
if len(visited_node_ids) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: BFS limited to {max_nodes} nodes"
)
break
# If inner loop reached node limit and exited, also exit outer loop
if len(visited_node_ids) >= max_nodes:
break
return result
async def get_knowledge_graph(
self,
node_label: str,
@@ -1836,69 +1989,40 @@ class PGGraphStorage(BaseGraphStorage):
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, defaults to 3
max_nodes: Maximum number of nodes to return, defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# First, count the total number of nodes that would be returned without limit
# Handle wildcard query - get all nodes
if node_label == "*":
# First check total node count to determine if graph should be truncated
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base)
RETURN count(distinct n) AS total_nodes
$$) AS (total_nodes bigint)"""
else:
strip_label = self._normalize_node_id(node_label)
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base {{entity_id: "{strip_label}"}})-[r]-()
RETURN count(r) AS total_nodes
$$) AS (total_nodes bigint)"""
count_result = await self._query(count_query)
total_nodes = count_result[0]["total_nodes"] if count_result else 0
is_truncated = total_nodes > max_nodes
# Now get the actual data with limit
if node_label == "*":
# Get nodes and edges
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base)
OPTIONAL MATCH (node)-[r]->()
RETURN collect(distinct node) AS n, collect(distinct r) AS r
LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
else:
strip_label = self._normalize_node_id(node_label)
if total_nodes > 0:
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base {{entity_id: "{strip_label}"}})
OPTIONAL MATCH p = (node)-[*..{max_depth}]-()
RETURN nodes(p) AS n, relationships(p) AS r
LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
else:
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base {{entity_id: "{strip_label}"}})
RETURN node AS n
$$) AS (n agtype)"""
results = await self._query(query)
# Process query results, deduplicate nodes and edges
nodes_dict = {}
edges_dict = {}
for result in results:
# Handle single node cases
if result.get("n") and isinstance(result["n"], dict):
node_id = str(result["n"]["id"])
if node_id not in nodes_dict:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[result["n"]["properties"]["entity_id"]],
properties=result["n"]["properties"],
)
# Handle node list cases
elif result.get("n") and isinstance(result["n"], list):
if result.get("n") and isinstance(result["n"], list):
for node in result["n"]:
if isinstance(node, dict) and "id" in node:
node_id = str(node["id"])
@@ -1909,19 +2033,7 @@ class PGGraphStorage(BaseGraphStorage):
properties=node["properties"],
)
# Handle single edge cases
if result.get("r") and isinstance(result["r"], dict):
edge_id = str(result["r"]["id"])
if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(result["r"]["start_id"]),
target=str(result["r"]["end_id"]),
properties=result["r"]["properties"],
)
# Handle edge list cases
elif result.get("r") and isinstance(result["r"], list):
if result.get("r") and isinstance(result["r"], list):
for edge in result["r"]:
if isinstance(edge, dict) and "id" in edge:
edge_id = str(edge["id"])
@@ -1934,12 +2046,14 @@ class PGGraphStorage(BaseGraphStorage):
properties=edge["properties"],
)
# Construct and return the KnowledgeGraph with deduplicated nodes and edges
kg = KnowledgeGraph(
nodes=list(nodes_dict.values()),
edges=list(edges_dict.values()),
is_truncated=is_truncated,
)
else:
# For single node query, use BFS algorithm
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"