Fix max_nodes not working in graph queries when using the '*' wildcard.

This commit is contained in:
yangdx
2025-04-25 21:25:37 +08:00
parent e3c87dd6bd
commit 5eb019a7fc

View File

@@ -1886,85 +1886,87 @@ class PGGraphStorage(BaseGraphStorage):
# Get all nodes at the current depth # Get all nodes at the current depth
current_level_nodes = [] current_level_nodes = []
current_depth = None current_depth = None
# Determine current depth # Determine current depth
if queue: if queue:
current_depth = queue[0][1] current_depth = queue[0][1]
# Extract all nodes at current depth from the queue # Extract all nodes at current depth from the queue
while queue and queue[0][1] == current_depth: while queue and queue[0][1] == current_depth:
node, depth = queue.popleft() node, depth = queue.popleft()
if depth > max_depth: if depth > max_depth:
continue continue
current_level_nodes.append(node) current_level_nodes.append(node)
if not current_level_nodes: if not current_level_nodes:
continue continue
# Check depth limit # Check depth limit
if current_depth > max_depth: if current_depth > max_depth:
continue continue
# Prepare node IDs list # Prepare node IDs list
node_ids = [node.labels[0] for node in current_level_nodes] node_ids = [node.labels[0] for node in current_level_nodes]
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]) formatted_ids = ", ".join(
[f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
)
# Construct batch query for outgoing edges # Construct batch query for outgoing edges
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id UNWIND [{formatted_ids}] AS node_id
MATCH (n:base {{entity_id: node_id}}) MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)-[r]->(neighbor:base) OPTIONAL MATCH (n)-[r]->(neighbor:base)
RETURN node_id AS current_id, RETURN node_id AS current_id,
id(n) AS current_internal_id, id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id, id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id, neighbor.entity_id AS neighbor_id,
id(r) AS edge_id, id(r) AS edge_id,
r, r,
neighbor, neighbor,
true AS is_outgoing true AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Construct batch query for incoming edges # Construct batch query for incoming edges
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
UNWIND [{formatted_ids}] AS node_id UNWIND [{formatted_ids}] AS node_id
MATCH (n:base {{entity_id: node_id}}) MATCH (n:base {{entity_id: node_id}})
OPTIONAL MATCH (n)<-[r]-(neighbor:base) OPTIONAL MATCH (n)<-[r]-(neighbor:base)
RETURN node_id AS current_id, RETURN node_id AS current_id,
id(n) AS current_internal_id, id(n) AS current_internal_id,
id(neighbor) AS neighbor_internal_id, id(neighbor) AS neighbor_internal_id,
neighbor.entity_id AS neighbor_id, neighbor.entity_id AS neighbor_id,
id(r) AS edge_id, id(r) AS edge_id,
r, r,
neighbor, neighbor,
false AS is_outgoing false AS is_outgoing
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
# Execute queries # Execute queries
outgoing_results = await self._query(outgoing_query) outgoing_results = await self._query(outgoing_query)
incoming_results = await self._query(incoming_query) incoming_results = await self._query(incoming_query)
# Combine results # Combine results
neighbors = outgoing_results + incoming_results neighbors = outgoing_results + incoming_results
# Create mapping from node ID to node object # Create mapping from node ID to node object
node_map = {node.labels[0]: node for node in current_level_nodes} node_map = {node.labels[0]: node for node in current_level_nodes}
# Process all results in a single loop # Process all results in a single loop
for record in neighbors: for record in neighbors:
if not record.get("neighbor") or not record.get("r"): if not record.get("neighbor") or not record.get("r"):
continue continue
# Get current node information # Get current node information
current_entity_id = record["current_id"] current_entity_id = record["current_id"]
current_node = node_map[current_entity_id] current_node = node_map[current_entity_id]
# Get neighbor node information # Get neighbor node information
neighbor_entity_id = record["neighbor_id"] neighbor_entity_id = record["neighbor_id"]
neighbor_internal_id = str(record["neighbor_internal_id"]) neighbor_internal_id = str(record["neighbor_internal_id"])
is_outgoing = record["is_outgoing"] is_outgoing = record["is_outgoing"]
# Determine edge direction # Determine edge direction
if is_outgoing: if is_outgoing:
source_id = current_node.id source_id = current_node.id
@@ -1972,25 +1974,25 @@ class PGGraphStorage(BaseGraphStorage):
else: else:
source_id = neighbor_internal_id source_id = neighbor_internal_id
target_id = current_node.id target_id = current_node.id
if not neighbor_entity_id: if not neighbor_entity_id:
continue continue
# Get edge and node information # Get edge and node information
b_node = record["neighbor"] b_node = record["neighbor"]
rel = record["r"] rel = record["r"]
edge_id = str(record["edge_id"]) edge_id = str(record["edge_id"])
# Create neighbor node object # Create neighbor node object
neighbor_node = KnowledgeGraphNode( neighbor_node = KnowledgeGraphNode(
id=neighbor_internal_id, id=neighbor_internal_id,
labels=[neighbor_entity_id], labels=[neighbor_entity_id],
properties=b_node["properties"], properties=b_node["properties"],
) )
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id])) sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
# Create edge object # Create edge object
edge = KnowledgeGraphEdge( edge = KnowledgeGraphEdge(
id=edge_id, id=edge_id,
@@ -1999,7 +2001,7 @@ class PGGraphStorage(BaseGraphStorage):
target=target_id, target=target_id,
properties=rel["properties"], properties=rel["properties"],
) )
if neighbor_internal_id in visited_node_ids: if neighbor_internal_id in visited_node_ids:
# Add backward edge if neighbor node is already visited # Add backward edge if neighbor node is already visited
if ( if (
@@ -2015,10 +2017,10 @@ class PGGraphStorage(BaseGraphStorage):
result.nodes.append(neighbor_node) result.nodes.append(neighbor_node)
visited_nodes.add(neighbor_entity_id) visited_nodes.add(neighbor_entity_id)
visited_node_ids.add(neighbor_internal_id) visited_node_ids.add(neighbor_internal_id)
# Add node to queue with incremented depth # Add node to queue with incremented depth
queue.append((neighbor_node, current_depth + 1)) queue.append((neighbor_node, current_depth + 1))
# Add forward edge # Add forward edge
if ( if (
edge_id not in visited_edges edge_id not in visited_edges
@@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage):
KnowledgeGraph object containing nodes and edges, with an is_truncated flag KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit indicating whether the graph was truncated due to max_nodes limit
""" """
# Initialize kg so that it is defined on every code path
kg = KnowledgeGraph()
# Handle wildcard query - get all nodes # Handle wildcard query - get all nodes
if node_label == "*": if node_label == "*":
# First check total node count to determine if graph should be truncated # First check total node count to determine if graph should be truncated
@@ -2063,57 +2068,91 @@ class PGGraphStorage(BaseGraphStorage):
total_nodes = count_result[0]["total_nodes"] if count_result else 0 total_nodes = count_result[0]["total_nodes"] if count_result else 0
is_truncated = total_nodes > max_nodes is_truncated = total_nodes > max_nodes
# Get nodes and edges # Get max_nodes with highest degrees
query = f"""SELECT * FROM cypher('{self.graph_name}', $$ query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
MATCH (node:base) MATCH (n:base)
OPTIONAL MATCH (node)-[r]->() OPTIONAL MATCH (n)-[r]->()
RETURN collect(distinct node) AS n, collect(distinct r) AS r RETURN id(n) as node_id, count(r) as degree
LIMIT {max_nodes} $$) AS (node_id BIGINT, degree BIGINT)
$$) AS (n agtype, r agtype)""" ORDER BY degree DESC
LIMIT {max_nodes}"""
node_results = await self._query(query_nodes)
results = await self._query(query) node_ids = [str(result["node_id"]) for result in node_results]
# Process query results, deduplicate nodes and edges logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
nodes_dict = {}
edges_dict = {}
for result in results: if node_ids:
if result.get("n") and isinstance(result["n"], list): formatted_ids = ", ".join(node_ids)
for node in result["n"]: # Construct batch query for subgraph within max_nodes
if isinstance(node, dict) and "id" in node: query = f"""SELECT * FROM cypher('{self.graph_name}', $$
node_id = str(node["id"]) WITH [{formatted_ids}] AS node_ids
if node_id not in nodes_dict and "properties" in node: MATCH (a)
nodes_dict[node_id] = KnowledgeGraphNode( WHERE id(a) IN node_ids
id=node_id, OPTIONAL MATCH (a)-[r]->(b)
labels=[node["properties"]["entity_id"]], WHERE id(b) IN node_ids
properties=node["properties"], RETURN a, r, b
) $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
results = await self._query(query)
if result.get("r") and isinstance(result["r"], list): # Process query results, deduplicate nodes and edges
for edge in result["r"]: nodes_dict = {}
if isinstance(edge, dict) and "id" in edge: edges_dict = {}
edge_id = str(edge["id"]) for result in results:
if edge_id not in edges_dict: # Process node a
edges_dict[edge_id] = KnowledgeGraphEdge( if result.get("a") and isinstance(result["a"], dict):
id=edge_id, node_a = result["a"]
type="DIRECTED", node_id = str(node_a["id"])
source=str(edge["start_id"]), if node_id not in nodes_dict and "properties" in node_a:
target=str(edge["end_id"]), nodes_dict[node_id] = KnowledgeGraphNode(
properties=edge["properties"], id=node_id,
) labels=[node_a["properties"]["entity_id"]],
properties=node_a["properties"],
)
kg = KnowledgeGraph( # Process node b
nodes=list(nodes_dict.values()), if result.get("b") and isinstance(result["b"], dict):
edges=list(edges_dict.values()), node_b = result["b"]
is_truncated=is_truncated, node_id = str(node_b["id"])
if node_id not in nodes_dict and "properties" in node_b:
nodes_dict[node_id] = KnowledgeGraphNode(
id=node_id,
labels=[node_b["properties"]["entity_id"]],
properties=node_b["properties"],
)
# Process edge r
if result.get("r") and isinstance(result["r"], dict):
edge = result["r"]
edge_id = str(edge["id"])
if edge_id not in edges_dict:
edges_dict[edge_id] = KnowledgeGraphEdge(
id=edge_id,
type=edge["label"],
source=str(edge["start_id"]),
target=str(edge["end_id"]),
properties=edge["properties"],
)
kg = KnowledgeGraph(
nodes=list(nodes_dict.values()),
edges=list(edges_dict.values()),
is_truncated=is_truncated,
)
else:
# For single node query, use BFS algorithm
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
) )
else: else:
# For single node query, use BFS algorithm # Non-wildcard query: use the BFS algorithm
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
logger.info(
f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
logger.info(
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
)
return kg return kg
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]: