Improve PostgreSQL AGE graph query by writing a customized BFS implementation
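With this change, wildcard queries keep the single-Cypher path, while single-entity queries are routed through the new `_bfs_subgraph` helper. A minimal usage sketch (assumes an already-initialized `PGGraphStorage` instance named `storage`; "Alan Turing" is a placeholder entity_id, and the defaults follow the docstrings in the diff below):

    async def demo(storage):
        # Wildcard: one Cypher query collects up to max_nodes nodes and their edges.
        kg_all = await storage.get_knowledge_graph("*", max_depth=3, max_nodes=1000)

        # Single entity: routed through the custom BFS, which guarantees
        # breadth-first ordering and stops once max_nodes is reached.
        kg_one = await storage.get_knowledge_graph(
            "Alan Turing", max_depth=3, max_nodes=1000
        )

        print(len(kg_one.nodes), len(kg_one.edges), kg_one.is_truncated)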
@@ -1824,6 +1824,159 @@ class PGGraphStorage(BaseGraphStorage):
                 labels.append(result["label"])
 
         return labels
 
+    async def _bfs_subgraph(
+        self, node_label: str, max_depth: int, max_nodes: int
+    ) -> KnowledgeGraph:
+        """
+        Implements a true breadth-first search algorithm for subgraph retrieval.
+        This method is used as a fallback when the standard Cypher query is too slow
+        or when we need to guarantee BFS ordering.
+
+        Args:
+            node_label: Label of the starting node
+            max_depth: Maximum depth of the subgraph
+            max_nodes: Maximum number of nodes to return
+
+        Returns:
+            KnowledgeGraph object containing nodes and edges
+        """
+        from collections import deque
+
+        result = KnowledgeGraph()
+        visited_nodes = set()
+        visited_node_ids = set()
+        visited_edges = set()
+        visited_edge_pairs = set()
+
+        # Get starting node data
+        label = self._normalize_node_id(node_label)
+        query = """SELECT * FROM cypher('%s', $$
+                     MATCH (n:base {entity_id: "%s"})
+                     RETURN id(n) as node_id, n
+                   $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label)
+
+        node_result = await self._query(query)
+        if not node_result or not node_result[0].get("n"):
+            return result
+
+        # Create initial KnowledgeGraphNode
+        start_node_data = node_result[0]["n"]
+        entity_id = start_node_data["properties"]["entity_id"]
+        internal_id = str(start_node_data["id"])
+
+        start_node = KnowledgeGraphNode(
+            id=internal_id,
+            labels=[entity_id],
+            properties=start_node_data["properties"],
+        )
+
+        # Initialize BFS queue, each element is a tuple of (node, depth)
+        queue = deque([(start_node, 0)])
+
+        visited_nodes.add(entity_id)
+        visited_node_ids.add(internal_id)
+        result.nodes.append(start_node)
+
+        while queue and len(visited_node_ids) < max_nodes:
+            # Dequeue the next node to process from the front of the queue
+            current_node, current_depth = queue.popleft()
+
+            if current_depth >= max_depth:
+                continue
+
+            # Get all edges and target nodes for the current node - query outgoing and incoming edges separately for efficiency
+            current_entity_id = current_node.labels[0]
+            outgoing_query = """SELECT * FROM cypher('%s', $$
+                         MATCH (a:base {entity_id: "%s"})-[r]->(b)
+                         WITH r, b, id(r) as edge_id, id(b) as target_id
+                         RETURN r, b, edge_id, target_id
+                       $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
+                self.graph_name,
+                current_entity_id,
+            )
+            incoming_query = """SELECT * FROM cypher('%s', $$
+                         MATCH (a:base {entity_id: "%s"})<-[r]-(b)
+                         WITH r, b, id(r) as edge_id, id(b) as target_id
+                         RETURN r, b, edge_id, target_id
+                       $$) AS (r agtype, b agtype, edge_id bigint, target_id bigint)""" % (
+                self.graph_name,
+                current_entity_id,
+            )
+
+            outgoing_neighbors = await self._query(outgoing_query)
+            incoming_neighbors = await self._query(incoming_query)
+            neighbors = outgoing_neighbors + incoming_neighbors
+
+            # logger.debug(f"Node {current_entity_id} has {len(neighbors)} neighbors (outgoing: {len(outgoing_neighbors)}, incoming: {len(incoming_neighbors)})")
+
+            for record in neighbors:
+                if not record.get("b") or not record.get("r"):
+                    continue
+
+                b_node = record["b"]
+                rel = record["r"]
+                edge_id = str(record["edge_id"])
+
+                if (
+                    "properties" not in b_node
+                    or "entity_id" not in b_node["properties"]
+                ):
+                    continue
+
+                target_entity_id = b_node["properties"]["entity_id"]
+                target_internal_id = str(b_node["id"])
+
+                # Create KnowledgeGraphNode for target
+                target_node = KnowledgeGraphNode(
+                    id=target_internal_id,
+                    labels=[target_entity_id],
+                    properties=b_node["properties"],
+                )
+
+                # Create edge object
+                edge = KnowledgeGraphEdge(
+                    id=edge_id,
+                    type=rel["label"],
+                    source=current_node.id,
+                    target=target_internal_id,
+                    properties=rel["properties"],
+                )
+
+                # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
+                sorted_pair = tuple(sorted([current_entity_id, target_entity_id]))
+
+                # Add edge (if not already added)
+                if (
+                    edge_id not in visited_edges
+                    and sorted_pair not in visited_edge_pairs
+                ):
+                    result.edges.append(edge)
+                    visited_edges.add(edge_id)
+                    visited_edge_pairs.add(sorted_pair)
+
+                # If target node not yet visited, add to result and queue
+                if target_internal_id not in visited_node_ids:
+                    result.nodes.append(target_node)
+                    visited_nodes.add(target_entity_id)
+                    visited_node_ids.add(target_internal_id)
+
+                    # Add node to queue with incremented depth
+                    queue.append((target_node, current_depth + 1))
+
+                    # If node limit reached, set truncated flag and exit
+                    if len(visited_node_ids) >= max_nodes:
+                        result.is_truncated = True
+                        logger.info(
+                            f"Graph truncated: BFS limited to {max_nodes} nodes"
+                        )
+                        break
+
+            # If inner loop reached node limit and exited, also exit outer loop
+            if len(visited_node_ids) >= max_nodes:
+                break
+
+        return result
+
     async def get_knowledge_graph(
         self,
         node_label: str,
@@ -1836,110 +1989,71 @@ class PGGraphStorage(BaseGraphStorage):
         Args:
             node_label: Label of the starting node, * means all nodes
             max_depth: Maximum depth of the subgraph, Defaults to 3
-            max_nodes: Maxiumu nodes to return, Defaults to 1000 (not BFS nor DFS garanteed)
+            max_nodes: Maximum nodes to return, Defaults to 1000
 
         Returns:
             KnowledgeGraph object containing nodes and edges, with an is_truncated flag
             indicating whether the graph was truncated due to max_nodes limit
         """
-        # First, count the total number of nodes that would be returned without limit
+        # Handle wildcard query - get all nodes
         if node_label == "*":
+            # First check total node count to determine if graph should be truncated
             count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 MATCH (n:base)
                 RETURN count(distinct n) AS total_nodes
             $$) AS (total_nodes bigint)"""
-        else:
-            strip_label = self._normalize_node_id(node_label)
-            count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-                MATCH (n:base {{entity_id: "{strip_label}"}})-[r]-()
-                RETURN count(r) AS total_nodes
-            $$) AS (total_nodes bigint)"""
 
-        count_result = await self._query(count_query)
-        total_nodes = count_result[0]["total_nodes"] if count_result else 0
-        is_truncated = total_nodes > max_nodes
+            count_result = await self._query(count_query)
+            total_nodes = count_result[0]["total_nodes"] if count_result else 0
+            is_truncated = total_nodes > max_nodes
 
-        # Now get the actual data with limit
-        if node_label == "*":
+            # Get nodes and edges
             query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 MATCH (node:base)
                 OPTIONAL MATCH (node)-[r]->()
                 RETURN collect(distinct node) AS n, collect(distinct r) AS r
                 LIMIT {max_nodes}
             $$) AS (n agtype, r agtype)"""
 
+            results = await self._query(query)
+
+            # Process query results, deduplicate nodes and edges
+            nodes_dict = {}
+            edges_dict = {}
+
+            for result in results:
+                if result.get("n") and isinstance(result["n"], list):
+                    for node in result["n"]:
+                        if isinstance(node, dict) and "id" in node:
+                            node_id = str(node["id"])
+                            if node_id not in nodes_dict and "properties" in node:
+                                nodes_dict[node_id] = KnowledgeGraphNode(
+                                    id=node_id,
+                                    labels=[node["properties"]["entity_id"]],
+                                    properties=node["properties"],
+                                )
+
+                if result.get("r") and isinstance(result["r"], list):
+                    for edge in result["r"]:
+                        if isinstance(edge, dict) and "id" in edge:
+                            edge_id = str(edge["id"])
+                            if edge_id not in edges_dict:
+                                edges_dict[edge_id] = KnowledgeGraphEdge(
+                                    id=edge_id,
+                                    type="DIRECTED",
+                                    source=str(edge["start_id"]),
+                                    target=str(edge["end_id"]),
+                                    properties=edge["properties"],
+                                )
+
+            kg = KnowledgeGraph(
+                nodes=list(nodes_dict.values()),
+                edges=list(edges_dict.values()),
+                is_truncated=is_truncated,
+            )
         else:
-            strip_label = self._normalize_node_id(node_label)
-            if total_nodes > 0:
-                query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-                    MATCH (node:base {{entity_id: "{strip_label}"}})
-                    OPTIONAL MATCH p = (node)-[*..{max_depth}]-()
-                    RETURN nodes(p) AS n, relationships(p) AS r
-                    LIMIT {max_nodes}
-                $$) AS (n agtype, r agtype)"""
-            else:
-                query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-                    MATCH (node:base {{entity_id: "{strip_label}"}})
-                    RETURN node AS n
-                $$) AS (n agtype)"""
-
-        results = await self._query(query)
-
-        # Process the query results with deduplication by node and edge IDs
-        nodes_dict = {}
-        edges_dict = {}
-        for result in results:
-            # Handle single node cases
-            if result.get("n") and isinstance(result["n"], dict):
-                node_id = str(result["n"]["id"])
-                if node_id not in nodes_dict:
-                    nodes_dict[node_id] = KnowledgeGraphNode(
-                        id=node_id,
-                        labels=[result["n"]["properties"]["entity_id"]],
-                        properties=result["n"]["properties"],
-                    )
-            # Handle node list cases
-            elif result.get("n") and isinstance(result["n"], list):
-                for node in result["n"]:
-                    if isinstance(node, dict) and "id" in node:
-                        node_id = str(node["id"])
-                        if node_id not in nodes_dict and "properties" in node:
-                            nodes_dict[node_id] = KnowledgeGraphNode(
-                                id=node_id,
-                                labels=[node["properties"]["entity_id"]],
-                                properties=node["properties"],
-                            )
-
-            # Handle single edge cases
-            if result.get("r") and isinstance(result["r"], dict):
-                edge_id = str(result["r"]["id"])
-                if edge_id not in edges_dict:
-                    edges_dict[edge_id] = KnowledgeGraphEdge(
-                        id=edge_id,
-                        type="DIRECTED",
-                        source=str(result["r"]["start_id"]),
-                        target=str(result["r"]["end_id"]),
-                        properties=result["r"]["properties"],
-                    )
-            # Handle edge list cases
-            elif result.get("r") and isinstance(result["r"], list):
-                for edge in result["r"]:
-                    if isinstance(edge, dict) and "id" in edge:
-                        edge_id = str(edge["id"])
-                        if edge_id not in edges_dict:
-                            edges_dict[edge_id] = KnowledgeGraphEdge(
-                                id=edge_id,
-                                type="DIRECTED",
-                                source=str(edge["start_id"]),
-                                target=str(edge["end_id"]),
-                                properties=edge["properties"],
-                            )
-
-        # Construct and return the KnowledgeGraph with deduplicated nodes and edges
-        kg = KnowledgeGraph(
-            nodes=list(nodes_dict.values()),
-            edges=list(edges_dict.values()),
-            is_truncated=is_truncated,
-        )
+            # For single node query, use BFS algorithm
+            kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
 
         logger.info(
             f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
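A side note on the edge deduplication in `_bfs_subgraph`: because the BFS queries both edge directions, the same relationship can surface twice, so the pair of entity ids is sorted before being used as a visited key. A standalone sketch of that idea (illustration only, not part of the commit):

    def undirected_key(source_entity: str, target_entity: str) -> tuple:
        # (A, B) and (B, A) map to the same key, so only one edge is kept.
        return tuple(sorted((source_entity, target_entity)))

    visited_edge_pairs = set()
    for pair in [("A", "B"), ("B", "A")]:
        key = undirected_key(*pair)
        if key not in visited_edge_pairs:
            visited_edge_pairs.add(key)

    print(len(visited_edge_pairs))  # 1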