Fix max_nodes not working in graph queries when using the '*' wildcard.

This commit is contained in:
yangdx
2025-04-25 21:25:37 +08:00
parent e3c87dd6bd
commit 5eb019a7fc

View File

@@ -1886,85 +1886,87 @@ class PGGraphStorage(BaseGraphStorage):
# Get all nodes at the current depth
current_level_nodes = []
current_depth = None
# Determine current depth
if queue:
current_depth = queue[0][1]
# Extract all nodes at current depth from the queue
while queue and queue[0][1] == current_depth:
node, depth = queue.popleft()
if depth > max_depth:
continue
current_level_nodes.append(node)
if not current_level_nodes:
continue
# Check depth limit
if current_depth > max_depth:
continue
# Prepare node IDs list
node_ids = [node.labels[0] for node in current_level_nodes]
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
formatted_ids = ", ".join(
[f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
)
# Construct batch query for outgoing edges
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)-[r]->(neighbor:base)
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
neighbor,
true AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Construct batch query for incoming edges
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id
MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
RETURN node_id AS current_id,
id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id,
id(r) AS edge_id,
r,
neighbor,
false AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Execute queries
outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query)
# Combine results
neighbors = outgoing_results + incoming_results
# Create mapping from node ID to node object
node_map = {node.labels[0]: node for node in current_level_nodes}
# Process all results in a single loop
for record in neighbors:
if not record.get("neighbor") or not record.get("r"):
continue
# Get current node information
current_entity_id = record["current_id"]
current_node = node_map[current_entity_id]
# Get neighbor node information
neighbor_entity_id = record["neighbor_id"]
neighbor_internal_id = str(record["neighbor_internal_id"])
is_outgoing = record["is_outgoing"]
# Determine edge direction
if is_outgoing:
source_id = current_node.id
@@ -1972,25 +1974,25 @@ class PGGraphStorage(BaseGraphStorage):
else:
source_id = neighbor_internal_id
target_id = current_node.id
if not neighbor_entity_id:
continue
# Get edge and node information
b_node = record["neighbor"]
rel = record["r"]
edge_id = str(record["edge_id"])
# Create neighbor node object
neighbor_node = KnowledgeGraphNode(
id=neighbor_internal_id,
labels=[neighbor_entity_id],
properties=b_node["properties"],
)
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
# Create edge object
edge = KnowledgeGraphEdge(
id=edge_id,
@@ -1999,7 +2001,7 @@ class PGGraphStorage(BaseGraphStorage):
target=target_id,
properties=rel["properties"],
)
if neighbor_internal_id in visited_node_ids:
# Add backward edge if neighbor node is already visited
if (
@@ -2015,10 +2017,10 @@ class PGGraphStorage(BaseGraphStorage):
result.nodes.append(neighbor_node)
visited_nodes.add(neighbor_entity_id)
visited_node_ids.add(neighbor_internal_id)
# Add node to queue with incremented depth
queue.append((neighbor_node, current_depth + 1))
# Add forward edge
if (
edge_id not in visited_edges
@@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
"""
# 初始化 kg 变量,确保在所有情况下都有定义
kg = KnowledgeGraph()
# Handle wildcard query - get all nodes
if node_label == "*":
# First check total node count to determine if graph should be truncated
@@ -2063,57 +2068,91 @@ class PGGraphStorage(BaseGraphStorage):
total_nodes = count_result[0]["total_nodes"] if count_result else 0
is_truncated = total_nodes > max_nodes
# Get nodes and edges
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base)
OPTIONAL MATCH (node)-[r]->()
RETURN collect(distinct node) AS n, collect(distinct r) AS r
LIMIT {max_nodes}
$$) AS (n agtype, r agtype)"""
# Get max_nodes with highest degrees
query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (n:base)
OPTIONAL MATCH (n)-[r]->()
RETURN id(n) as node_id, count(r) as degree
$$) AS (node_id BIGINT, degree BIGINT)
ORDER BY degree DESC
LIMIT {max_nodes}"""
node_results = await self._query(query_nodes)
results = await self._query(query)
node_ids = [str(result["node_id"]) for result in node_results]
# Process query results, deduplicate nodes and edges
nodes_dict = {}
edges_dict = {}
logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
for result in results:
if result.get("n") and isinstance(result["n"], list):
for node in result["n"]:
if isinstance(node, dict) and "id" in node:
node_id = str(node["id"])
if node_id not in nodes_dict and "properties" in node:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node["properties"]["entity_id"]],
properties=node["properties"],
)
if node_ids:
formatted_ids = ", ".join(node_ids)
# Construct batch query for subgraph within max_nodes
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
WITH [{formatted_ids}] AS node_ids
MATCH (a)
WHERE id(a) IN node_ids
OPTIONAL MATCH (a)-[r]->(b)
WHERE id(b) IN node_ids
RETURN a, r, b
$$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
results = await self._query(query)
if result.get("r") and isinstance(result["r"], list):
for edge in result["r"]:
if isinstance(edge, dict) and "id" in edge:
edge_id = str(edge["id"])
if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=str(edge["start_id"]),
target=str(edge["end_id"]),
properties=edge["properties"],
)
# Process query results, deduplicate nodes and edges
nodes_dict = {}
edges_dict = {}
for result in results:
# 处理节点 a
if result.get("a") and isinstance(result["a"], dict):
node_a = result["a"]
node_id = str(node_a["id"])
if node_id not in nodes_dict and "properties" in node_a:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node_a["properties"]["entity_id"]],
properties=node_a["properties"],
)
kg = KnowledgeGraph(
nodes=list(nodes_dict.values()),
edges=list(edges_dict.values()),
is_truncated=is_truncated,
# 处理节点 b
if result.get("b") and isinstance(result["b"], dict):
node_b = result["b"]
node_id = str(node_b["id"])
if node_id not in nodes_dict and "properties" in node_b:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node_b["properties"]["entity_id"]],
properties=node_b["properties"],
)
# 处理边 r
if result.get("r") and isinstance(result["r"], dict):
edge = result["r"]
edge_id = str(edge["id"])
if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id,
type=edge["label"],
source=str(edge["start_id"]),
target=str(edge["end_id"]),
properties=edge["properties"],
)
kg = KnowledgeGraph(
nodes=list(nodes_dict.values()),
edges=list(edges_dict.values()),
is_truncated=is_truncated,
)
else:
# For single node query, use BFS algorithm
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
else:
# For single node query, use BFS algorithm
# 非通配符查询,使用 BFS 算法
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
return kg
async def drop(self) -> dict[str, str]: