Fix max_nodes not working in graph queries when using the '*' wildcard.
This commit is contained in:
@@ -1886,85 +1886,87 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
# Get all nodes at the current depth
|
# Get all nodes at the current depth
|
||||||
current_level_nodes = []
|
current_level_nodes = []
|
||||||
current_depth = None
|
current_depth = None
|
||||||
|
|
||||||
# Determine current depth
|
# Determine current depth
|
||||||
if queue:
|
if queue:
|
||||||
current_depth = queue[0][1]
|
current_depth = queue[0][1]
|
||||||
|
|
||||||
# Extract all nodes at current depth from the queue
|
# Extract all nodes at current depth from the queue
|
||||||
while queue and queue[0][1] == current_depth:
|
while queue and queue[0][1] == current_depth:
|
||||||
node, depth = queue.popleft()
|
node, depth = queue.popleft()
|
||||||
if depth > max_depth:
|
if depth > max_depth:
|
||||||
continue
|
continue
|
||||||
current_level_nodes.append(node)
|
current_level_nodes.append(node)
|
||||||
|
|
||||||
if not current_level_nodes:
|
if not current_level_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check depth limit
|
# Check depth limit
|
||||||
if current_depth > max_depth:
|
if current_depth > max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Prepare node IDs list
|
# Prepare node IDs list
|
||||||
node_ids = [node.labels[0] for node in current_level_nodes]
|
node_ids = [node.labels[0] for node in current_level_nodes]
|
||||||
formatted_ids = ", ".join([f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids])
|
formatted_ids = ", ".join(
|
||||||
|
[f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids]
|
||||||
|
)
|
||||||
|
|
||||||
# Construct batch query for outgoing edges
|
# Construct batch query for outgoing edges
|
||||||
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
UNWIND [{formatted_ids}] AS node_id
|
UNWIND [{formatted_ids}] AS node_id
|
||||||
MATCH (n:base {{entity_id: node_id}})
|
MATCH (n:base {{entity_id: node_id}})
|
||||||
OPTIONAL MATCH (n)-[r]->(neighbor:base)
|
OPTIONAL MATCH (n)-[r]->(neighbor:base)
|
||||||
RETURN node_id AS current_id,
|
RETURN node_id AS current_id,
|
||||||
id(n) AS current_internal_id,
|
id(n) AS current_internal_id,
|
||||||
id(neighbor) AS neighbor_internal_id,
|
id(neighbor) AS neighbor_internal_id,
|
||||||
neighbor.entity_id AS neighbor_id,
|
neighbor.entity_id AS neighbor_id,
|
||||||
id(r) AS edge_id,
|
id(r) AS edge_id,
|
||||||
r,
|
r,
|
||||||
neighbor,
|
neighbor,
|
||||||
true AS is_outgoing
|
true AS is_outgoing
|
||||||
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
||||||
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
||||||
|
|
||||||
# Construct batch query for incoming edges
|
# Construct batch query for incoming edges
|
||||||
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
UNWIND [{formatted_ids}] AS node_id
|
UNWIND [{formatted_ids}] AS node_id
|
||||||
MATCH (n:base {{entity_id: node_id}})
|
MATCH (n:base {{entity_id: node_id}})
|
||||||
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
|
OPTIONAL MATCH (n)<-[r]-(neighbor:base)
|
||||||
RETURN node_id AS current_id,
|
RETURN node_id AS current_id,
|
||||||
id(n) AS current_internal_id,
|
id(n) AS current_internal_id,
|
||||||
id(neighbor) AS neighbor_internal_id,
|
id(neighbor) AS neighbor_internal_id,
|
||||||
neighbor.entity_id AS neighbor_id,
|
neighbor.entity_id AS neighbor_id,
|
||||||
id(r) AS edge_id,
|
id(r) AS edge_id,
|
||||||
r,
|
r,
|
||||||
neighbor,
|
neighbor,
|
||||||
false AS is_outgoing
|
false AS is_outgoing
|
||||||
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
$$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint,
|
||||||
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)"""
|
||||||
|
|
||||||
# Execute queries
|
# Execute queries
|
||||||
outgoing_results = await self._query(outgoing_query)
|
outgoing_results = await self._query(outgoing_query)
|
||||||
incoming_results = await self._query(incoming_query)
|
incoming_results = await self._query(incoming_query)
|
||||||
|
|
||||||
# Combine results
|
# Combine results
|
||||||
neighbors = outgoing_results + incoming_results
|
neighbors = outgoing_results + incoming_results
|
||||||
|
|
||||||
# Create mapping from node ID to node object
|
# Create mapping from node ID to node object
|
||||||
node_map = {node.labels[0]: node for node in current_level_nodes}
|
node_map = {node.labels[0]: node for node in current_level_nodes}
|
||||||
|
|
||||||
# Process all results in a single loop
|
# Process all results in a single loop
|
||||||
for record in neighbors:
|
for record in neighbors:
|
||||||
if not record.get("neighbor") or not record.get("r"):
|
if not record.get("neighbor") or not record.get("r"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get current node information
|
# Get current node information
|
||||||
current_entity_id = record["current_id"]
|
current_entity_id = record["current_id"]
|
||||||
current_node = node_map[current_entity_id]
|
current_node = node_map[current_entity_id]
|
||||||
|
|
||||||
# Get neighbor node information
|
# Get neighbor node information
|
||||||
neighbor_entity_id = record["neighbor_id"]
|
neighbor_entity_id = record["neighbor_id"]
|
||||||
neighbor_internal_id = str(record["neighbor_internal_id"])
|
neighbor_internal_id = str(record["neighbor_internal_id"])
|
||||||
is_outgoing = record["is_outgoing"]
|
is_outgoing = record["is_outgoing"]
|
||||||
|
|
||||||
# Determine edge direction
|
# Determine edge direction
|
||||||
if is_outgoing:
|
if is_outgoing:
|
||||||
source_id = current_node.id
|
source_id = current_node.id
|
||||||
@@ -1972,25 +1974,25 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
else:
|
else:
|
||||||
source_id = neighbor_internal_id
|
source_id = neighbor_internal_id
|
||||||
target_id = current_node.id
|
target_id = current_node.id
|
||||||
|
|
||||||
if not neighbor_entity_id:
|
if not neighbor_entity_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get edge and node information
|
# Get edge and node information
|
||||||
b_node = record["neighbor"]
|
b_node = record["neighbor"]
|
||||||
rel = record["r"]
|
rel = record["r"]
|
||||||
edge_id = str(record["edge_id"])
|
edge_id = str(record["edge_id"])
|
||||||
|
|
||||||
# Create neighbor node object
|
# Create neighbor node object
|
||||||
neighbor_node = KnowledgeGraphNode(
|
neighbor_node = KnowledgeGraphNode(
|
||||||
id=neighbor_internal_id,
|
id=neighbor_internal_id,
|
||||||
labels=[neighbor_entity_id],
|
labels=[neighbor_entity_id],
|
||||||
properties=b_node["properties"],
|
properties=b_node["properties"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
# Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge
|
||||||
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
|
sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id]))
|
||||||
|
|
||||||
# Create edge object
|
# Create edge object
|
||||||
edge = KnowledgeGraphEdge(
|
edge = KnowledgeGraphEdge(
|
||||||
id=edge_id,
|
id=edge_id,
|
||||||
@@ -1999,7 +2001,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
target=target_id,
|
target=target_id,
|
||||||
properties=rel["properties"],
|
properties=rel["properties"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if neighbor_internal_id in visited_node_ids:
|
if neighbor_internal_id in visited_node_ids:
|
||||||
# Add backward edge if neighbor node is already visited
|
# Add backward edge if neighbor node is already visited
|
||||||
if (
|
if (
|
||||||
@@ -2015,10 +2017,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
result.nodes.append(neighbor_node)
|
result.nodes.append(neighbor_node)
|
||||||
visited_nodes.add(neighbor_entity_id)
|
visited_nodes.add(neighbor_entity_id)
|
||||||
visited_node_ids.add(neighbor_internal_id)
|
visited_node_ids.add(neighbor_internal_id)
|
||||||
|
|
||||||
# Add node to queue with incremented depth
|
# Add node to queue with incremented depth
|
||||||
queue.append((neighbor_node, current_depth + 1))
|
queue.append((neighbor_node, current_depth + 1))
|
||||||
|
|
||||||
# Add forward edge
|
# Add forward edge
|
||||||
if (
|
if (
|
||||||
edge_id not in visited_edges
|
edge_id not in visited_edges
|
||||||
@@ -2051,6 +2053,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
|
||||||
indicating whether the graph was truncated due to max_nodes limit
|
indicating whether the graph was truncated due to max_nodes limit
|
||||||
"""
|
"""
|
||||||
|
# 初始化 kg 变量,确保在所有情况下都有定义
|
||||||
|
kg = KnowledgeGraph()
|
||||||
|
|
||||||
# Handle wildcard query - get all nodes
|
# Handle wildcard query - get all nodes
|
||||||
if node_label == "*":
|
if node_label == "*":
|
||||||
# First check total node count to determine if graph should be truncated
|
# First check total node count to determine if graph should be truncated
|
||||||
@@ -2063,57 +2068,91 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
total_nodes = count_result[0]["total_nodes"] if count_result else 0
|
||||||
is_truncated = total_nodes > max_nodes
|
is_truncated = total_nodes > max_nodes
|
||||||
|
|
||||||
# Get nodes and edges
|
# Get max_nodes with highest degrees
|
||||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
MATCH (node:base)
|
MATCH (n:base)
|
||||||
OPTIONAL MATCH (node)-[r]->()
|
OPTIONAL MATCH (n)-[r]->()
|
||||||
RETURN collect(distinct node) AS n, collect(distinct r) AS r
|
RETURN id(n) as node_id, count(r) as degree
|
||||||
LIMIT {max_nodes}
|
$$) AS (node_id BIGINT, degree BIGINT)
|
||||||
$$) AS (n agtype, r agtype)"""
|
ORDER BY degree DESC
|
||||||
|
LIMIT {max_nodes}"""
|
||||||
|
node_results = await self._query(query_nodes)
|
||||||
|
|
||||||
results = await self._query(query)
|
node_ids = [str(result["node_id"]) for result in node_results]
|
||||||
|
|
||||||
# Process query results, deduplicate nodes and edges
|
logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}")
|
||||||
nodes_dict = {}
|
|
||||||
edges_dict = {}
|
|
||||||
|
|
||||||
for result in results:
|
if node_ids:
|
||||||
if result.get("n") and isinstance(result["n"], list):
|
formatted_ids = ", ".join(node_ids)
|
||||||
for node in result["n"]:
|
# Construct batch query for subgraph within max_nodes
|
||||||
if isinstance(node, dict) and "id" in node:
|
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||||
node_id = str(node["id"])
|
WITH [{formatted_ids}] AS node_ids
|
||||||
if node_id not in nodes_dict and "properties" in node:
|
MATCH (a)
|
||||||
nodes_dict[node_id] = KnowledgeGraphNode(
|
WHERE id(a) IN node_ids
|
||||||
id=node_id,
|
OPTIONAL MATCH (a)-[r]->(b)
|
||||||
labels=[node["properties"]["entity_id"]],
|
WHERE id(b) IN node_ids
|
||||||
properties=node["properties"],
|
RETURN a, r, b
|
||||||
)
|
$$) AS (a AGTYPE, r AGTYPE, b AGTYPE)"""
|
||||||
|
results = await self._query(query)
|
||||||
|
|
||||||
if result.get("r") and isinstance(result["r"], list):
|
# Process query results, deduplicate nodes and edges
|
||||||
for edge in result["r"]:
|
nodes_dict = {}
|
||||||
if isinstance(edge, dict) and "id" in edge:
|
edges_dict = {}
|
||||||
edge_id = str(edge["id"])
|
for result in results:
|
||||||
if edge_id not in edges_dict:
|
# 处理节点 a
|
||||||
edges_dict[edge_id] = KnowledgeGraphEdge(
|
if result.get("a") and isinstance(result["a"], dict):
|
||||||
id=edge_id,
|
node_a = result["a"]
|
||||||
type="DIRECTED",
|
node_id = str(node_a["id"])
|
||||||
source=str(edge["start_id"]),
|
if node_id not in nodes_dict and "properties" in node_a:
|
||||||
target=str(edge["end_id"]),
|
nodes_dict[node_id] = KnowledgeGraphNode(
|
||||||
properties=edge["properties"],
|
id=node_id,
|
||||||
)
|
labels=[node_a["properties"]["entity_id"]],
|
||||||
|
properties=node_a["properties"],
|
||||||
|
)
|
||||||
|
|
||||||
kg = KnowledgeGraph(
|
# 处理节点 b
|
||||||
nodes=list(nodes_dict.values()),
|
if result.get("b") and isinstance(result["b"], dict):
|
||||||
edges=list(edges_dict.values()),
|
node_b = result["b"]
|
||||||
is_truncated=is_truncated,
|
node_id = str(node_b["id"])
|
||||||
|
if node_id not in nodes_dict and "properties" in node_b:
|
||||||
|
nodes_dict[node_id] = KnowledgeGraphNode(
|
||||||
|
id=node_id,
|
||||||
|
labels=[node_b["properties"]["entity_id"]],
|
||||||
|
properties=node_b["properties"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理边 r
|
||||||
|
if result.get("r") and isinstance(result["r"], dict):
|
||||||
|
edge = result["r"]
|
||||||
|
edge_id = str(edge["id"])
|
||||||
|
if edge_id not in edges_dict:
|
||||||
|
edges_dict[edge_id] = KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type=edge["label"],
|
||||||
|
source=str(edge["start_id"]),
|
||||||
|
target=str(edge["end_id"]),
|
||||||
|
properties=edge["properties"],
|
||||||
|
)
|
||||||
|
|
||||||
|
kg = KnowledgeGraph(
|
||||||
|
nodes=list(nodes_dict.values()),
|
||||||
|
edges=list(edges_dict.values()),
|
||||||
|
is_truncated=is_truncated,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For single node query, use BFS algorithm
|
||||||
|
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# For single node query, use BFS algorithm
|
# 非通配符查询,使用 BFS 算法
|
||||||
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
|
kg = await self._bfs_subgraph(node_label, max_depth, max_nodes)
|
||||||
|
logger.info(
|
||||||
|
f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}"
|
|
||||||
)
|
|
||||||
return kg
|
return kg
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
|
Reference in New Issue
Block a user