Fix max_nodes not working in graph queries when using the '*' wildcard.

This commit is contained in:
yangdx
2025-04-25 21:25:37 +08:00
parent e3c87dd6bd
commit 5eb019a7fc

View File

@@ -1907,7 +1907,9 @@ class PGGraphStorage(BaseGraphStorage):
# Prepare node IDs list
node_ids = [node.labels[0] for node in current_level_nodes]
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
formatted_ids = ", ".join(
[f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
)
# Construct batch query for outgoing edges
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
@@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# 初始化 kg 变量,确保在所有情况下都有定义
kg = KnowledgeGraph()
# Handle wildcard query - get all nodes
if node_label == "*":
# First check total node count to determine if graph should be truncated
@@ -2063,40 +2068,67 @@ class PGGraphStorage(BaseGraphStorage):
total_nodes = count_result[0]["total_nodes"] if count_result else 0
is_truncated = total_nodes > max_nodes
# Get nodes and edges
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base)
OPTIONAL MATCH (node)-[r]->()
RETURN collect(distinct node) AS n, collect(distinct r) AS r
LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
# Get max_nodes with highest degrees
query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base)
OPTIONAL MATCH (n)-[r]->()
RETURN id(n) as node_id, count(r) as degree
$$) AS (node_id BIGINT, degree BIGINT)
ORDER BY degree DESC
LIMIT {max_nodes}"""
node_results = await self._query(query_nodes)
node_ids = [str(result["node_id"]) for result in node_results]
logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
if node_ids:
formatted_ids = ", ".join(node_ids)
# Construct batch query for subgraph within max_nodes
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{formatted_ids}] AS node_ids
MATCH (a)
WHERE id(a) IN node_ids
OPTIONAL MATCH (a)-[r]->(b)
WHERE id(b) IN node_ids
RETURN a, r, b
$$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
results = await self._query(query)
# Process query results, deduplicate nodes and edges
nodes_dict = {}
edges_dict = {}
for result in results:
if result.get("n") and isinstance(result["n"], list):
for node in result["n"]:
if isinstance(node, dict) and "id" in node:
node_id = str(node["id"])
if node_id not in nodes_dict and "properties" in node:
# 处理节点 a
if result.get("a") and isinstance(result["a"], dict):
node_a = result["a"]
node_id = str(node_a["id"])
if node_id not in nodes_dict and "properties" in node_a:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node["properties"]["entity_id"]],
properties=node["properties"],
labels=[node_a["properties"]["entity_id"]],
properties=node_a["properties"],
)
if result.get("r") and isinstance(result["r"], list):
for edge in result["r"]:
if isinstance(edge, dict) and "id" in edge:
# 处理节点 b
if result.get("b") and isinstance(result["b"], dict):
node_b = result["b"]
node_id = str(node_b["id"])
if node_id not in nodes_dict and "properties" in node_b:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node_b["properties"]["entity_id"]],
properties=node_b["properties"],
)
# 处理边 r
if result.get("r") and isinstance(result["r"], dict):
edge = result["r"]
edge_id = str(edge["id"])
if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
type=edge["label"],
source=str(edge["start_id"]),
target=str(edge["end_id"]),
properties=edge["properties"],
@@ -2114,6 +2146,13 @@ class PGGraphStorage(BaseGraphStorage):
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
else:
# 非通配符查询,使用 BFS 算法
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
return kg
async def drop(self) -> dict[str, str]: