Fix max_nodes not working in graph queries when using the '*' wildcard.

This commit is contained in:
yangdx
2025-04-25 21:25:37 +08:00
parent e3c87dd6bd
commit 5eb019a7fc

View File

@@ -1907,7 +1907,9 @@ class PGGraphStorage(BaseGraphStorage):
# Prepare node IDs list # Prepare node IDs list
node_ids = [node.labels[0] for node in current_level_nodes] node_ids = [node.labels[0] for node in current_level_nodes]
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]) formatted_ids = ", ".join(
[f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
)
# Construct batch query for outgoing edges # Construct batch query for outgoing edges
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
@@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit indicating whether the graph was truncated due to max_nodes limit
""" """
# 初始化 kg 变量,确保在所有情况下都有定义
kg = KnowledgeGraph()
# Handle wildcard query - get all nodes # Handle wildcard query - get all nodes
if node_label == "*": if node_label == "*":
# First check total node count to determine if graph should be truncated # First check total node count to determine if graph should be truncated
@@ -2063,40 +2068,67 @@ class PGGraphStorage(BaseGraphStorage):
total_nodes = count_result[0]["total_nodes"] if count_result else 0 total_nodes = count_result[0]["total_nodes"] if count_result else 0
is_truncated = total_nodes > max_nodes is_truncated = total_nodes > max_nodes
# Get nodes and edges # Get max_nodes with highest degrees
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base) MATCH (n:base)
OPTIONAL MATCH (node)-[r]->() OPTIONAL MATCH (n)-[r]->()
RETURN collect(distinct node) AS n, collect(distinct r) AS r RETURN id(n) as node_id, count(r) as degree
LIMIT {max_nodes} $$) AS (node_id BIGINT, degree BIGINT)
$$) AS (n agtype, r agtype)""" ORDER BY degree DESC
LIMIT {max_nodes}"""
node_results = await self._query(query_nodes)
node_ids = [str(result["node_id"]) for result in node_results]
logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
if node_ids:
formatted_ids = ", ".join(node_ids)
# Construct batch query for subgraph within max_nodes
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{formatted_ids}] AS node_ids
MATCH (a)
WHERE id(a) IN node_ids
OPTIONAL MATCH (a)-[r]->(b)
WHERE id(b) IN node_ids
RETURN a, r, b
$$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
results = await self._query(query) results = await self._query(query)
# Process query results, deduplicate nodes and edges # Process query results, deduplicate nodes and edges
nodes_dict = {} nodes_dict = {}
edges_dict = {} edges_dict = {}
for result in results: for result in results:
if result.get("n") and isinstance(result["n"], list): # 处理节点 a
for node in result["n"]: if result.get("a") and isinstance(result["a"], dict):
if isinstance(node, dict) and "id" in node: node_a = result["a"]
node_id = str(node["id"]) node_id = str(node_a["id"])
if node_id not in nodes_dict and "properties" in node: if node_id not in nodes_dict and "properties" in node_a:
nodes_dict[node_id] = KnowledgeGraphNode( nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id, id=node_id,
labels=[node["properties"]["entity_id"]], labels=[node_a["properties"]["entity_id"]],
properties=node["properties"], properties=node_a["properties"],
) )
if result.get("r") and isinstance(result["r"], list): # 处理节点 b
for edge in result["r"]: if result.get("b") and isinstance(result["b"], dict):
if isinstance(edge, dict) and "id" in edge: node_b = result["b"]
node_id = str(node_b["id"])
if node_id not in nodes_dict and "properties" in node_b:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node_b["properties"]["entity_id"]],
properties=node_b["properties"],
)
# 处理边 r
if result.get("r") and isinstance(result["r"], dict):
edge = result["r"]
edge_id = str(edge["id"]) edge_id = str(edge["id"])
if edge_id not in edges_dict: if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge( edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id, id=edge_id,
type="DIRECTED", type=edge["label"],
source=str(edge["start_id"]), source=str(edge["start_id"]),
target=str(edge["end_id"]), target=str(edge["end_id"]),
properties=edge["properties"], properties=edge["properties"],
@@ -2114,6 +2146,13 @@ class PGGraphStorage(BaseGraphStorage):
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
) )
else:
# 非通配符查询,使用 BFS 算法
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
return kg return kg
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]: