Fix max_nodes not working in graph queries when using the '*' wildcard.
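Previously, the wildcard ('*') branch fetched the whole graph in one query and tried to cap it with LIMIT {max_nodes}. Because the RETURN clause aggregates with collect(), that query yields a single row, so the LIMIT capped rows rather than nodes and the limit never took effect. The fix ranks nodes by degree and applies the limit there (ORDER BY degree DESC ... LIMIT {max_nodes}), then fetches only the subgraph induced by the selected node ids. kg is also initialized up front so it stays defined when the selection comes back empty.

The selection strategy, as a minimal standalone sketch on toy data (illustration only, not the committed code):

    def truncate_by_degree(degrees: dict[str, int], edges: list[tuple[str, str]], max_nodes: int):
        # Keep the max_nodes highest-degree nodes, mirroring
        # ORDER BY degree DESC ... LIMIT {max_nodes} in the new query_nodes query.
        kept = set(sorted(degrees, key=degrees.get, reverse=True)[:max_nodes])
        # Keep an edge only when both endpoints survive, mirroring the
        # WHERE id(a) IN node_ids / WHERE id(b) IN node_ids filters.
        return kept, [(s, t) for s, t in edges if s in kept and t in kept]

    degrees = {"A": 3, "B": 2, "C": 1, "D": 0}
    edges = [("A", "B"), ("A", "C"), ("A", "D"), ("B", "C")]
    nodes, sub_edges = truncate_by_degree(degrees, edges, max_nodes=3)
    print(sorted(nodes))  # ['A', 'B', 'C']
    print(sub_edges)      # [('A', 'B'), ('A', 'C'), ('B', 'C')]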
@@ -1886,85 +1886,87 @@ class PGGraphStorage(BaseGraphStorage):
             # Get all nodes at the current depth
             current_level_nodes = []
             current_depth = None

             # Determine current depth
             if queue:
                 current_depth = queue[0][1]

             # Extract all nodes at current depth from the queue
             while queue and queue[0][1] == current_depth:
                 node, depth = queue.popleft()
                 if depth > max_depth:
                     continue
                 current_level_nodes.append(node)

             if not current_level_nodes:
                 continue

             # Check depth limit
             if current_depth > max_depth:
                 continue

             # Prepare node IDs list
             node_ids = [node.labels[0] for node in current_level_nodes]
-            formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
+            formatted_ids = ", ".join(
+                [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
+            )

             # Construct batch query for outgoing edges
             outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 UNWIND [{formatted_ids}] AS node_id
                 MATCH (n:base {{entity_id: node_id}})
                 OPTIONAL MATCH (n)-[r]->(neighbor:base)
                 RETURN node_id AS current_id,
                        id(n) AS current_internal_id,
                        id(neighbor) AS neighbor_internal_id,
                        neighbor.entity_id AS neighbor_id,
                        id(r) AS edge_id,
                        r,
                        neighbor,
                        true AS is_outgoing
             $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
                     neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""

             # Construct batch query for incoming edges
             incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 UNWIND [{formatted_ids}] AS node_id
                 MATCH (n:base {{entity_id: node_id}})
                 OPTIONAL MATCH (n)<-[r]-(neighbor:base)
                 RETURN node_id AS current_id,
                        id(n) AS current_internal_id,
                        id(neighbor) AS neighbor_internal_id,
                        neighbor.entity_id AS neighbor_id,
                        id(r) AS edge_id,
                        r,
                        neighbor,
                        false AS is_outgoing
             $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
                     neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""

             # Execute queries
             outgoing_results = await self._query(outgoing_query)
             incoming_results = await self._query(incoming_query)

             # Combine results
             neighbors = outgoing_results + incoming_results

             # Create mapping from node ID to node object
             node_map = {node.labels[0]: node for node in current_level_nodes}

             # Process all results in a single loop
             for record in neighbors:
                 if not record.get("neighbor") or not record.get("r"):
                     continue

                 # Get current node information
                 current_entity_id = record["current_id"]
                 current_node = node_map[current_entity_id]

                 # Get neighbor node information
                 neighbor_entity_id = record["neighbor_id"]
                 neighbor_internal_id = str(record["neighbor_internal_id"])
                 is_outgoing = record["is_outgoing"]

                 # Determine edge direction
                 if is_outgoing:
                     source_id = current_node.id
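A toy sketch of the id-quoting step used when splicing entity ids into the UNWIND clause above. _normalize_node_id is internal to PGGraphStorage and not shown in this diff, so normalize_node_id below is a stand-in assumed to escape embedded double quotes:

    def normalize_node_id(node_id: str) -> str:
        # Assumption: the real helper at least escapes embedded double quotes
        return node_id.replace('"', '\\"')

    node_ids = ["Alice", 'Bob "the builder"']
    formatted_ids = ", ".join(f'"{normalize_node_id(n)}"' for n in node_ids)
    print(f"UNWIND [{formatted_ids}] AS node_id")
    # UNWIND ["Alice", "Bob \"the builder\""] AS node_id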
@@ -1972,25 +1974,25 @@ class PGGraphStorage(BaseGraphStorage):
                 else:
                     source_id = neighbor_internal_id
                     target_id = current_node.id

                 if not neighbor_entity_id:
                     continue

                 # Get edge and node information
                 b_node = record["neighbor"]
                 rel = record["r"]
                 edge_id = str(record["edge_id"])

                 # Create neighbor node object
                 neighbor_node = KnowledgeGraphNode(
                     id=neighbor_internal_id,
                     labels=[neighbor_entity_id],
                     properties=b_node["properties"],
                 )

                 # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
                 sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))

                 # Create edge object
                 edge = KnowledgeGraphEdge(
                     id=edge_id,
@@ -1999,7 +2001,7 @@ class PGGraphStorage(BaseGraphStorage):
                     target=target_id,
                     properties=rel["properties"],
                 )

                 if neighbor_internal_id in visited_node_ids:
                     # Add backward edge if neighbor node is already visited
                     if (
@@ -2015,10 +2017,10 @@ class PGGraphStorage(BaseGraphStorage):
                     result.nodes.append(neighbor_node)
                     visited_nodes.add(neighbor_entity_id)
                     visited_node_ids.add(neighbor_internal_id)

                     # Add node to queue with incremented depth
                     queue.append((neighbor_node, current_depth + 1))

                 # Add forward edge
                 if (
                     edge_id not in visited_edges
@@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage):
             KnowledgeGraph object containing nodes and edges, with an is_truncated flag
             indicating whether the graph was truncated due to max_nodes limit
         """
+        # Initialize kg so it is defined on every code path
+        kg = KnowledgeGraph()
+
         # Handle wildcard query - get all nodes
         if node_label == "*":
             # First check total node count to determine if graph should be truncated
@@ -2063,57 +2068,91 @@ class PGGraphStorage(BaseGraphStorage):
             total_nodes = count_result[0]["total_nodes"] if count_result else 0
             is_truncated = total_nodes > max_nodes

-            # Get nodes and edges
-            query = f"""SELECT * FROM cypher('{self.graph_name}', $$
-                MATCH (node:base)
-                OPTIONAL MATCH (node)-[r]->()
-                RETURN collect(distinct node) AS n, collect(distinct r) AS r
-                LIMIT {max_nodes}
-            $$) AS (n agtype, r agtype)"""
-
-            results = await self._query(query)
-
-            # Process query results, deduplicate nodes and edges
-            nodes_dict = {}
-            edges_dict = {}
-
-            for result in results:
-                if result.get("n") and isinstance(result["n"], list):
-                    for node in result["n"]:
-                        if isinstance(node, dict) and "id" in node:
-                            node_id = str(node["id"])
-                            if node_id not in nodes_dict and "properties" in node:
-                                nodes_dict[node_id] = KnowledgeGraphNode(
-                                    id=node_id,
-                                    labels=[node["properties"]["entity_id"]],
-                                    properties=node["properties"],
-                                )
-
-                if result.get("r") and isinstance(result["r"], list):
-                    for edge in result["r"]:
-                        if isinstance(edge, dict) and "id" in edge:
-                            edge_id = str(edge["id"])
-                            if edge_id not in edges_dict:
-                                edges_dict[edge_id] = KnowledgeGraphEdge(
-                                    id=edge_id,
-                                    type="DIRECTED",
-                                    source=str(edge["start_id"]),
-                                    target=str(edge["end_id"]),
-                                    properties=edge["properties"],
-                                )
-
-            kg = KnowledgeGraph(
-                nodes=list(nodes_dict.values()),
-                edges=list(edges_dict.values()),
-                is_truncated=is_truncated,
-            )
-        else:
-            # For single node query, use BFS algorithm
-            kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
-
-        logger.info(
-            f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
-        )
+            # Get the max_nodes nodes with the highest degrees
+            query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
+                MATCH (n:base)
+                OPTIONAL MATCH (n)-[r]->()
+                RETURN id(n) as node_id, count(r) as degree
+            $$) AS (node_id BIGINT, degree BIGINT)
+            ORDER BY degree DESC
+            LIMIT {max_nodes}"""
+            node_results = await self._query(query_nodes)
+
+            node_ids = [str(result["node_id"]) for result in node_results]
+
+            logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
+
+            if node_ids:
+                formatted_ids = ", ".join(node_ids)
+                # Construct batch query for subgraph within max_nodes
+                query = f"""SELECT * FROM cypher('{self.graph_name}', $$
+                    WITH [{formatted_ids}] AS node_ids
+                    MATCH (a)
+                    WHERE id(a) IN node_ids
+                    OPTIONAL MATCH (a)-[r]->(b)
+                    WHERE id(b) IN node_ids
+                    RETURN a, r, b
+                $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
+                results = await self._query(query)
+
+                # Process query results, deduplicate nodes and edges
+                nodes_dict = {}
+                edges_dict = {}
+                for result in results:
+                    # Process node a
+                    if result.get("a") and isinstance(result["a"], dict):
+                        node_a = result["a"]
+                        node_id = str(node_a["id"])
+                        if node_id not in nodes_dict and "properties" in node_a:
+                            nodes_dict[node_id] = KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_a["properties"]["entity_id"]],
+                                properties=node_a["properties"],
+                            )
+
+                    # Process node b
+                    if result.get("b") and isinstance(result["b"], dict):
+                        node_b = result["b"]
+                        node_id = str(node_b["id"])
+                        if node_id not in nodes_dict and "properties" in node_b:
+                            nodes_dict[node_id] = KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_b["properties"]["entity_id"]],
+                                properties=node_b["properties"],
+                            )
+
+                    # Process edge r
+                    if result.get("r") and isinstance(result["r"], dict):
+                        edge = result["r"]
+                        edge_id = str(edge["id"])
+                        if edge_id not in edges_dict:
+                            edges_dict[edge_id] = KnowledgeGraphEdge(
+                                id=edge_id,
+                                type=edge["label"],
+                                source=str(edge["start_id"]),
+                                target=str(edge["end_id"]),
+                                properties=edge["properties"],
+                            )
+
+                kg = KnowledgeGraph(
+                    nodes=list(nodes_dict.values()),
+                    edges=list(edges_dict.values()),
+                    is_truncated=is_truncated,
+                )
+        else:
+            # Non-wildcard query, use the BFS algorithm
+            kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
+            logger.info(
+                f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
+            )
+
+        logger.info(
+            f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
+        )

         return kg

     async def drop(self) -> dict[str, str]:
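How the fixed path behaves from the caller's side, as a hedged usage sketch (assumes an already-initialized PGGraphStorage and a running PostgreSQL with Apache AGE; the constructor setup and parameter defaults are assumptions, not part of this commit):

    import asyncio

    async def main(storage):
        # storage: an initialized PGGraphStorage instance
        # With the fix, a wildcard query honours max_nodes: at most the 200
        # highest-degree nodes are returned, and is_truncated reports the cut.
        kg = await storage.get_knowledge_graph("*", max_depth=3, max_nodes=200)
        assert len(kg.nodes) <= 200
        print(kg.is_truncated)  # True iff the full graph has more than 200 nodes

    # asyncio.run(main(storage))  # supply your own storage instance

Truncating by degree keeps the most connected part of the graph, so a capped wildcard view stays informative instead of being an arbitrary LIMIT-ordered slice.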