Improve PostgreSQL AGE graph query by writing a customized BFS implementation
This commit is contained in:
@@ -1824,6 +1824,159 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
labels.append(result["label"])
|
||||
return labels
|
||||
|
||||
async def _bfs_subgraph(
    self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
    """
    Implements a true breadth-first search algorithm for subgraph retrieval.

    This method is used as a fallback when the standard Cypher query is too slow
    or when we need to guarantee BFS ordering.

    Args:
        node_label: Label (entity_id) of the starting node
        max_depth: Maximum depth of the subgraph
        max_nodes: Maximum number of nodes to return

    Returns:
        KnowledgeGraph object containing nodes and edges; ``is_truncated`` is
        set to True when traversal stops early because the node limit was hit
    """
    from collections import deque

    result = KnowledgeGraph()
    visited_node_ids = set()    # AGE internal node ids already added to result
    visited_edges = set()       # AGE internal edge ids already added to result
    visited_edge_pairs = set()  # undirected (entity_id, entity_id) pairs so that
                                # (A, B) and (B, A) count as the same edge

    # Get starting node data
    label = self._normalize_node_id(node_label)
    query = """SELECT * FROM cypher('%s', $$
          MATCH (n:base {entity_id: "%s"})
          RETURN id(n) as node_id, n
        $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label)

    node_result = await self._query(query)
    if not node_result or not node_result[0].get("n"):
        # Start node does not exist: return an empty graph
        return result

    # Create initial KnowledgeGraphNode
    start_node_data = node_result[0]["n"]
    entity_id = start_node_data["properties"]["entity_id"]
    internal_id = str(start_node_data["id"])

    start_node = KnowledgeGraphNode(
        id=internal_id,
        labels=[entity_id],
        properties=start_node_data["properties"],
    )

    # Initialize BFS queue, each element is a tuple of (node, depth)
    queue = deque([(start_node, 0)])

    visited_node_ids.add(internal_id)
    result.nodes.append(start_node)

    while queue and len(visited_node_ids) < max_nodes:
        # Dequeue the next node to process from the front of the queue
        current_node, current_depth = queue.popleft()

        # Nodes at the depth limit are kept but not expanded further
        if current_depth >= max_depth:
            continue

        # Get all edges and target nodes for the current node - query
        # outgoing and incoming edges separately for efficiency. The two
        # queries differ only in edge direction, so build them in one loop.
        current_entity_id = current_node.labels[0]
        neighbors = []
        for direction in ("-[r]->", "<-[r]-"):
            neighbor_query = """SELECT * FROM cypher('%s', $$
              MATCH (a:base {entity_id: "%s"})%s(b)
              WITH r, b, id(r) as edge_id, id(b) as target_id
              RETURN r, b, edge_id, target_id
            $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
                self.graph_name,
                current_entity_id,
                direction,
            )
            neighbors.extend(await self._query(neighbor_query))

        for record in neighbors:
            if not record.get("b") or not record.get("r"):
                continue

            b_node = record["b"]
            rel = record["r"]
            edge_id = str(record["edge_id"])

            # Skip malformed neighbor rows lacking an entity_id
            if (
                "properties" not in b_node
                or "entity_id" not in b_node["properties"]
            ):
                continue

            target_entity_id = b_node["properties"]["entity_id"]
            target_internal_id = str(b_node["id"])

            # Create KnowledgeGraphNode for target
            target_node = KnowledgeGraphNode(
                id=target_internal_id,
                labels=[target_entity_id],
                properties=b_node["properties"],
            )

            # Create edge object.
            # NOTE(review): incoming edges are emitted with
            # source=current_node.id too, i.e. edge direction from the DB is
            # not preserved here; edges are effectively undirected (see the
            # sorted_pair dedup below) — confirm this is intended.
            edge = KnowledgeGraphEdge(
                id=edge_id,
                type=rel["label"],
                source=current_node.id,
                target=target_internal_id,
                properties=rel["properties"],
            )

            # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
            sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))

            # Add edge (if not already added in either direction)
            if (
                edge_id not in visited_edges
                and sorted_pair not in visited_edge_pairs
            ):
                result.edges.append(edge)
                visited_edges.add(edge_id)
                visited_edge_pairs.add(sorted_pair)

            # If target node not yet visited, add to result and queue
            if target_internal_id not in visited_node_ids:
                result.nodes.append(target_node)
                visited_node_ids.add(target_internal_id)

                # Add node to queue with incremented depth
                queue.append((target_node, current_depth + 1))

                # If node limit reached, set truncated flag and exit
                if len(visited_node_ids) >= max_nodes:
                    result.is_truncated = True
                    logger.info(
                        f"Graph truncated: BFS limited to {max_nodes} nodes"
                    )
                    break

        # If inner loop reached node limit and exited, also exit outer loop
        if len(visited_node_ids) >= max_nodes:
            break

    return result
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self,
|
||||
node_label: str,
|
||||
@@ -1836,69 +1989,40 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Args:
|
||||
node_label: Label of the starting node, * means all nodes
|
||||
max_depth: Maximum depth of the subgraph, Defaults to 3
|
||||
max_nodes: Maximum nodes to return, Defaults to 1000 (neither BFS nor DFS guaranteed)
|
||||
max_nodes: Maximum nodes to return, Defaults to 1000
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||
indicating whether the graph was truncated due to max_nodes limit
|
||||
"""
|
||||
# First, count the total number of nodes that would be returned without limit
|
||||
# Handle wildcard query - get all nodes
|
||||
if node_label == "*":
|
||||
# First check total node count to determine if graph should be truncated
|
||||
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:base)
|
||||
RETURN count(distinct n) AS total_nodes
|
||||
$$) AS (total_nodes bigint)"""
|
||||
else:
|
||||
strip_label = self._normalize_node_id(node_label)
|
||||
count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (n:base {{entity_id: "{strip_label}"}})-[r]-()
|
||||
RETURN count(r) AS total_nodes
|
||||
$$) AS (total_nodes bigint)"""
|
||||
|
||||
count_result = await self._query(count_query)
|
||||
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
||||
is_truncated = total_nodes > max_nodes
|
||||
|
||||
# Now get the actual data with limit
|
||||
if node_label == "*":
|
||||
# Get nodes and edges
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (node:base)
|
||||
OPTIONAL MATCH (node)-[r]->()
|
||||
RETURN collect(distinct node) AS n, collect(distinct r) AS r
|
||||
LIMIT {max_nodes}
|
||||
$$) AS (n agtype, r agtype)"""
|
||||
else:
|
||||
strip_label = self._normalize_node_id(node_label)
|
||||
if total_nodes > 0:
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (node:base {{entity_id: "{strip_label}"}})
|
||||
OPTIONAL MATCH p = (node)-[*..{max_depth}]-()
|
||||
RETURN nodes(p) AS n, relationships(p) AS r
|
||||
LIMIT {max_nodes}
|
||||
$$) AS (n agtype, r agtype)"""
|
||||
else:
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
MATCH (node:base {{entity_id: "{strip_label}"}})
|
||||
RETURN node AS n
|
||||
$$) AS (n agtype)"""
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
# Process the query results with deduplication by node and edge IDs
|
||||
# Process query results, deduplicate nodes and edges
|
||||
nodes_dict = {}
|
||||
edges_dict = {}
|
||||
|
||||
for result in results:
|
||||
# Handle single node cases
|
||||
if result.get("n") and isinstance(result["n"], dict):
|
||||
node_id = str(result["n"]["id"])
|
||||
if node_id not in nodes_dict:
|
||||
nodes_dict[node_id] = KnowledgeGraphNode(
|
||||
id=node_id,
|
||||
labels=[result["n"]["properties"]["entity_id"]],
|
||||
properties=result["n"]["properties"],
|
||||
)
|
||||
# Handle node list cases
|
||||
elif result.get("n") and isinstance(result["n"], list):
|
||||
if result.get("n") and isinstance(result["n"], list):
|
||||
for node in result["n"]:
|
||||
if isinstance(node, dict) and "id" in node:
|
||||
node_id = str(node["id"])
|
||||
@@ -1909,19 +2033,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
properties=node["properties"],
|
||||
)
|
||||
|
||||
# Handle single edge cases
|
||||
if result.get("r") and isinstance(result["r"], dict):
|
||||
edge_id = str(result["r"]["id"])
|
||||
if edge_id not in edges_dict:
|
||||
edges_dict[edge_id] = KnowledgeGraphEdge(
|
||||
id=edge_id,
|
||||
type="DIRECTED",
|
||||
source=str(result["r"]["start_id"]),
|
||||
target=str(result["r"]["end_id"]),
|
||||
properties=result["r"]["properties"],
|
||||
)
|
||||
# Handle edge list cases
|
||||
elif result.get("r") and isinstance(result["r"], list):
|
||||
if result.get("r") and isinstance(result["r"], list):
|
||||
for edge in result["r"]:
|
||||
if isinstance(edge, dict) and "id" in edge:
|
||||
edge_id = str(edge["id"])
|
||||
@@ -1934,12 +2046,14 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
properties=edge["properties"],
|
||||
)
|
||||
|
||||
# Construct and return the KnowledgeGraph with deduplicated nodes and edges
|
||||
kg = KnowledgeGraph(
|
||||
nodes=list(nodes_dict.values()),
|
||||
edges=list(edges_dict.values()),
|
||||
is_truncated=is_truncated,
|
||||
)
|
||||
else:
|
||||
# For single node query, use BFS algorithm
|
||||
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
|
||||
|
Reference in New Issue
Block a user