Fix linting

This commit is contained in:
yangdx
2025-04-15 12:34:04 +08:00
parent 22b03ee1bb
commit 1de74c9228
7 changed files with 249 additions and 150 deletions

View File

@@ -388,7 +388,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[node_id] = degree result[node_id] = degree
return result return result
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]: async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch """Edge degrees as a batch using UNWIND also uses node_degrees_batch
Default implementation calculates edge degrees one by one. Default implementation calculates edge degrees one by one.
@@ -401,7 +403,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = degree result[(src_id, tgt_id)] = degree
return result return result
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
"""Get edges as a batch using UNWIND """Get edges as a batch using UNWIND
Default implementation fetches edges one by one. Default implementation fetches edges one by one.
@@ -417,7 +421,9 @@ class BaseGraphStorage(StorageNameSpace, ABC):
result[(src_id, tgt_id)] = edge result[(src_id, tgt_id)] = edge
return result return result
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
"""Get nodes edges as a batch using UNWIND """Get nodes edges as a batch using UNWIND
Default implementation fetches node edges one by one. Default implementation fetches node edges one by one.

View File

@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
node_dict = dict(node) node_dict = dict(node)
# Remove the 'base' label if present in a 'labels' property # Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict: if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes[entity_id] = node_dict nodes[entity_id] = node_dict
await result.consume() # Make sure to consume the result fully await result.consume() # Make sure to consume the result fully
return nodes return nodes
@@ -437,7 +439,9 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]: async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
""" """
Calculate the combined degree for each edge (sum of the source and target node degrees) Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch. in batch using the already implemented node_degrees_batch.
@@ -547,7 +551,9 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
if edges and len(edges) > 0: if edges and len(edges) > 0:
edge_props = edges[0] # choose the first if multiple exist edge_props = edges[0] # choose the first if multiple exist
# Ensure required keys exist with defaults # Ensure required keys exist with defaults
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items(): for key, default in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_props: if key not in edge_props:
edge_props[key] = default edge_props[key] = default
edges_dict[(src, tgt)] = edge_props edges_dict[(src, tgt)] = edge_props
else: else:
# No edge found set default edge properties # No edge found set default edge properties
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None} edges_dict[(src, tgt)] = {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}
await result.consume() await result.consume()
return edges_dict return edges_dict
@@ -644,7 +660,9 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise raise
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
""" """
Batch retrieve edges for multiple nodes in one query using UNWIND. Batch retrieve edges for multiple nodes in one query using UNWIND.
@@ -654,7 +672,9 @@ class Neo4JStorage(BaseGraphStorage):
Returns: Returns:
A dictionary mapping each node ID to its list of edge tuples (source, target). A dictionary mapping each node ID to its list of edge tuples (source, target).
""" """
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session: async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = """ query = """
UNWIND $node_ids AS id UNWIND $node_ids AS id
MATCH (n:base {entity_id: id}) MATCH (n:base {entity_id: id})

View File

@@ -1472,16 +1472,15 @@ class PGGraphStorage(BaseGraphStorage):
return {} return {}
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
MATCH (n:base {entity_id: node_id}) MATCH (n:base {entity_id: node_id})
RETURN node_id, n RETURN node_id, n
$$) AS (node_id text, n agtype)""" % ( $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
self.graph_name,
formatted_ids
)
results = await self._query(query) results = await self._query(query)
@@ -1492,7 +1491,9 @@ class PGGraphStorage(BaseGraphStorage):
node_dict = result["n"]["properties"] node_dict = result["n"]["properties"]
# Remove the 'base' label if present in a 'labels' property # Remove the 'base' label if present in a 'labels' property
if "labels" in node_dict: if "labels" in node_dict:
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"] node_dict["labels"] = [
label for label in node_dict["labels"] if label != "base"
]
nodes_dict[result["node_id"]] = node_dict nodes_dict[result["node_id"]] = node_dict
return nodes_dict return nodes_dict
@@ -1512,7 +1513,9 @@ class PGGraphStorage(BaseGraphStorage):
return {} return {}
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
@@ -1521,7 +1524,7 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, count(r) AS degree RETURN node_id, count(r) AS degree
$$) AS (node_id text, degree bigint)""" % ( $$) AS (node_id text, degree bigint)""" % (
self.graph_name, self.graph_name,
formatted_ids formatted_ids,
) )
results = await self._query(query) results = await self._query(query)
@@ -1539,7 +1542,9 @@ class PGGraphStorage(BaseGraphStorage):
return degrees_dict return degrees_dict
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]: async def edge_degrees_batch(
self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
""" """
Calculate the combined degree for each edge (sum of the source and target node degrees) Calculate the combined degree for each edge (sum of the source and target node degrees)
in batch using the already implemented node_degrees_batch. in batch using the already implemented node_degrees_batch.
@@ -1570,7 +1575,9 @@ class PGGraphStorage(BaseGraphStorage):
return edge_degrees_dict return edge_degrees_dict
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]: async def get_edges_batch(
self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
""" """
Retrieve edge properties for multiple (src, tgt) pairs in one query. Retrieve edge properties for multiple (src, tgt) pairs in one query.
@@ -1587,8 +1594,8 @@ class PGGraphStorage(BaseGraphStorage):
src_nodes = [] src_nodes = []
tgt_nodes = [] tgt_nodes = []
for pair in pairs: for pair in pairs:
src_nodes.append(pair["src"].replace('"', '')) src_nodes.append(pair["src"].replace('"', ""))
tgt_nodes.append(pair["tgt"].replace('"', '')) tgt_nodes.append(pair["tgt"].replace('"', ""))
# 构建查询,使用数组索引来匹配源节点和目标节点 # 构建查询,使用数组索引来匹配源节点和目标节点
src_array = ", ".join([f'"{src}"' for src in src_nodes]) src_array = ", ".join([f'"{src}"' for src in src_nodes])
@@ -1607,11 +1614,15 @@ class PGGraphStorage(BaseGraphStorage):
edges_dict = {} edges_dict = {}
for result in results: for result in results:
if result["source"] and result["target"] and result["edge_properties"]: if result["source"] and result["target"] and result["edge_properties"]:
edges_dict[(result["source"], result["target"])] = result["edge_properties"] edges_dict[(result["source"], result["target"])] = result[
"edge_properties"
]
return edges_dict return edges_dict
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]: async def get_nodes_edges_batch(
self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
""" """
Get all edges for multiple nodes in a single batch operation. Get all edges for multiple nodes in a single batch operation.
@@ -1625,7 +1636,9 @@ class PGGraphStorage(BaseGraphStorage):
return {} return {}
# Format node IDs for the query # Format node IDs for the query
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids]) formatted_ids = ", ".join(
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
)
query = """SELECT * FROM cypher('%s', $$ query = """SELECT * FROM cypher('%s', $$
UNWIND [%s] AS node_id UNWIND [%s] AS node_id
@@ -1634,7 +1647,7 @@ class PGGraphStorage(BaseGraphStorage):
RETURN node_id, connected.entity_id AS connected_id RETURN node_id, connected.entity_id AS connected_id
$$) AS (node_id text, connected_id text)""" % ( $$) AS (node_id text, connected_id text)""" % (
self.graph_name, self.graph_name,
formatted_ids formatted_ids,
) )
results = await self._query(query) results = await self._query(query)

View File

@@ -1330,7 +1330,7 @@ async def _get_node_data(
# Call the batch node retrieval and degree functions concurrently. # Call the batch node retrieval and degree functions concurrently.
nodes_dict, degrees_dict = await asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(node_ids), knowledge_graph_inst.get_nodes_batch(node_ids),
knowledge_graph_inst.node_degrees_batch(node_ids) knowledge_graph_inst.node_degrees_batch(node_ids),
) )
# Now, if you need the node data and degree in order: # Now, if you need the node data and degree in order:
@@ -1474,8 +1474,12 @@ async def _find_most_related_text_unit_from_entities(
all_one_hop_nodes = list(all_one_hop_nodes) all_one_hop_nodes = list(all_one_hop_nodes)
# Batch retrieve one-hop node data using get_nodes_batch # Batch retrieve one-hop node data using get_nodes_batch
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes) all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes] all_one_hop_nodes
)
all_one_hop_nodes_data = [
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
]
# Add null check for node data # Add null check for node data
all_one_hop_text_units_lookup = { all_one_hop_text_units_lookup = {
@@ -1575,7 +1579,7 @@ async def _find_most_related_edges_from_entities(
# Call the batched functions concurrently. # Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather( edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples) knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
) )
# Reconstruct edge_datas list in the same order as the deduplicated results. # Reconstruct edge_datas list in the same order as the deduplicated results.
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
} }
all_edges_data.append(combined) all_edges_data.append(combined)
all_edges_data = sorted( all_edges_data = sorted(
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
) )
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
# Call the batched functions concurrently. # Call the batched functions concurrently.
edge_data_dict, edge_degrees_dict = await asyncio.gather( edge_data_dict, edge_degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples) knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
) )
# Reconstruct edge_datas list in the same order as results. # Reconstruct edge_datas list in the same order as results.
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
# Batch approach: Retrieve nodes and their degrees concurrently with one query each. # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
nodes_dict, degrees_dict = await asyncio.gather( nodes_dict, degrees_dict = await asyncio.gather(
knowledge_graph_inst.get_nodes_batch(entity_names), knowledge_graph_inst.get_nodes_batch(entity_names),
knowledge_graph_inst.node_degrees_batch(entity_names) knowledge_graph_inst.node_degrees_batch(entity_names),
) )
# Rebuild the list in the same order as entity_names # Rebuild the list in the same order as entity_names

View File

@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中" assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中" assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中" assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配" assert (
assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配" nodes_dict[node1_id]["description"] == node1_data["description"]
assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配" ), f"{node1_id} 描述不匹配"
assert (
nodes_dict[node2_id]["description"] == node2_data["description"]
), f"{node2_id} 描述不匹配"
assert (
nodes_dict[node3_id]["description"] == node3_data["description"]
), f"{node3_id} 描述不匹配"
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数 # 3. 测试 node_degrees_batch - 批量获取多个节点的度数
print("== 测试 node_degrees_batch") print("== 测试 node_degrees_batch")
node_degrees = await storage.node_degrees_batch(node_ids) node_degrees = await storage.node_degrees_batch(node_ids)
print(f"批量获取节点度数结果: {node_degrees}") print(f"批量获取节点度数结果: {node_degrees}")
assert len(node_degrees) == 3, f"应返回3个节点的度数实际返回 {len(node_degrees)}" assert (
len(node_degrees) == 3
), f"应返回3个节点的度数实际返回 {len(node_degrees)}"
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中" assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中" assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中" assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}" assert (
assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}" node_degrees[node1_id] == 3
assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}" ), f"{node1_id} 度数应为3实际为 {node_degrees[node1_id]}"
assert (
node_degrees[node2_id] == 2
), f"{node2_id} 度数应为2实际为 {node_degrees[node2_id]}"
assert (
node_degrees[node3_id] == 3
), f"{node3_id} 度数应为3实际为 {node_degrees[node3_id]}"
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数 # 4. 测试 edge_degrees_batch - 批量获取多个边的度数
print("== 测试 edge_degrees_batch") print("== 测试 edge_degrees_batch")
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)] edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
edge_degrees = await storage.edge_degrees_batch(edges) edge_degrees = await storage.edge_degrees_batch(edges)
print(f"批量获取边度数结果: {edge_degrees}") print(f"批量获取边度数结果: {edge_degrees}")
assert len(edge_degrees) == 3, f"应返回3条边的度数实际返回 {len(edge_degrees)}" assert (
assert (node1_id, node2_id) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中" len(edge_degrees) == 3
assert (node2_id, node3_id) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中" ), f"应返回3条边的度数实际返回 {len(edge_degrees)}"
assert (node3_id, node4_id) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中" assert (
node1_id,
node2_id,
) in edge_degrees, f"{node1_id} -> {node2_id} 应在返回结果中"
assert (
node2_id,
node3_id,
) in edge_degrees, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edge_degrees, f"{node3_id} -> {node4_id} 应在返回结果中"
# 验证边的度数是否正确(源节点度数 + 目标节点度数) # 验证边的度数是否正确(源节点度数 + 目标节点度数)
assert edge_degrees[(node1_id, node2_id)] == 5, f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}" assert (
assert edge_degrees[(node2_id, node3_id)] == 5, f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}" edge_degrees[(node1_id, node2_id)] == 5
assert edge_degrees[(node3_id, node4_id)] == 5, f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}" ), f"{node1_id} -> {node2_id} 度数应为5实际为 {edge_degrees[(node1_id, node2_id)]}"
assert (
edge_degrees[(node2_id, node3_id)] == 5
), f"{node2_id} -> {node3_id} 度数应为5实际为 {edge_degrees[(node2_id, node3_id)]}"
assert (
edge_degrees[(node3_id, node4_id)] == 5
), f"{node3_id} -> {node4_id} 度数应为5实际为 {edge_degrees[(node3_id, node4_id)]}"
# 5. 测试 get_edges_batch - 批量获取多个边的属性 # 5. 测试 get_edges_batch - 批量获取多个边的属性
print("== 测试 get_edges_batch") print("== 测试 get_edges_batch")
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
edges_dict = await storage.get_edges_batch(edge_dicts) edges_dict = await storage.get_edges_batch(edge_dicts)
print(f"批量获取边属性结果: {edges_dict.keys()}") print(f"批量获取边属性结果: {edges_dict.keys()}")
assert len(edges_dict) == 3, f"应返回3条边的属性实际返回 {len(edges_dict)}" assert len(edges_dict) == 3, f"应返回3条边的属性实际返回 {len(edges_dict)}"
assert (node1_id, node2_id) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中" assert (
assert (node2_id, node3_id) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中" node1_id,
assert (node3_id, node4_id) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中" node2_id,
assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"{node1_id} -> {node2_id} 关系不匹配" ) in edges_dict, f"{node1_id} -> {node2_id} 应在返回结果中"
assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"{node2_id} -> {node3_id} 关系不匹配" assert (
assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"{node3_id} -> {node4_id} 关系不匹配" node2_id,
node3_id,
) in edges_dict, f"{node2_id} -> {node3_id} 应在返回结果中"
assert (
node3_id,
node4_id,
) in edges_dict, f"{node3_id} -> {node4_id} 应在返回结果中"
assert (
edges_dict[(node1_id, node2_id)]["relationship"]
== edge1_data["relationship"]
), f"{node1_id} -> {node2_id} 关系不匹配"
assert (
edges_dict[(node2_id, node3_id)]["relationship"]
== edge2_data["relationship"]
), f"{node2_id} -> {node3_id} 关系不匹配"
assert (
edges_dict[(node3_id, node4_id)]["relationship"]
== edge5_data["relationship"]
), f"{node3_id} -> {node4_id} 关系不匹配"
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边 # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
print("== 测试 get_nodes_edges_batch") print("== 测试 get_nodes_edges_batch")
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id]) nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
print(f"批量获取节点边结果: {nodes_edges.keys()}") print(f"批量获取节点边结果: {nodes_edges.keys()}")
assert len(nodes_edges) == 2, f"应返回2个节点的边实际返回 {len(nodes_edges)}" assert (
len(nodes_edges) == 2
), f"应返回2个节点的边实际返回 {len(nodes_edges)}"
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中" assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中" assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}" assert (
assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}" len(nodes_edges[node1_id]) == 3
), f"{node1_id} 应有3条边实际有 {len(nodes_edges[node1_id])}"
assert (
len(nodes_edges[node3_id]) == 3
), f"{node3_id} 应有3条边实际有 {len(nodes_edges[node3_id])}"
# 7. 清理数据 # 7. 清理数据
print("== 测试 drop") print("== 测试 drop")
result = await storage.drop() result = await storage.drop()
print(f"清理结果: {result}") print(f"清理结果: {result}")
assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}" assert (
result["status"] == "success"
), f"清理应成功,实际状态为 {result['status']}"
print("\n批量操作测试完成") print("\n批量操作测试完成")
return True return True