From 262c93d8da0c7dfbaf3d38b8e7d3604bb8b9cc2c Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 13 Apr 2025 01:07:39 +0800 Subject: [PATCH] Add batch query unit test for grap storage --- tests/test_graph_storage.py | 214 +++++++++++++++++++++++++++++++++++- 1 file changed, 211 insertions(+), 3 deletions(-) diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 0103cac5..eb59594b 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -377,6 +377,207 @@ async def test_graph_advanced(storage): return False +async def test_graph_batch_operations(storage): + """ + 测试图数据库的批量操作: + 1. 使用 get_nodes_batch 批量获取多个节点的属性 + 2. 使用 node_degrees_batch 批量获取多个节点的度数 + 3. 使用 edge_degrees_batch 批量获取多个边的度数 + 4. 使用 get_edges_batch 批量获取多个边的属性 + 5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边 + """ + try: + # 清理之前的测试数据 + print("清理之前的测试数据...\n") + await storage.drop() + + # 1. 插入测试数据 + # 插入节点1: 人工智能 + node1_id = "人工智能" + node1_data = { + "entity_id": node1_id, + "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。", + "keywords": "AI,机器学习,深度学习", + "entity_type": "技术领域", + } + print(f"插入节点1: {node1_id}") + await storage.upsert_node(node1_id, node1_data) + + # 插入节点2: 机器学习 + node2_id = "机器学习" + node2_data = { + "entity_id": node2_id, + "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。", + "keywords": "监督学习,无监督学习,强化学习", + "entity_type": "技术领域", + } + print(f"插入节点2: {node2_id}") + await storage.upsert_node(node2_id, node2_data) + + # 插入节点3: 深度学习 + node3_id = "深度学习" + node3_data = { + "entity_id": node3_id, + "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。", + "keywords": "神经网络,CNN,RNN", + "entity_type": "技术领域", + } + print(f"插入节点3: {node3_id}") + await storage.upsert_node(node3_id, node3_data) + + # 插入节点4: 自然语言处理 + node4_id = "自然语言处理" + node4_data = { + "entity_id": node4_id, + "description": "自然语言处理是人工智能的一个分支,专注于使计算机理解和处理人类语言。", + "keywords": "NLP,文本分析,语言模型", + "entity_type": "技术领域", + } + print(f"插入节点4: {node4_id}") + await storage.upsert_node(node4_id, node4_data) + + # 插入节点5: 计算机视觉 + node5_id = "计算机视觉" + node5_data = { + "entity_id": node5_id, + "description": "计算机视觉是人工智能的一个分支,专注于使计算机能够从图像或视频中获取信息。", + "keywords": "CV,图像识别,目标检测", + "entity_type": "技术领域", + } + print(f"插入节点5: {node5_id}") + await storage.upsert_node(node5_id, node5_data) + + # 插入边1: 人工智能 -> 机器学习 + edge1_data = { + "relationship": "包含", + "weight": 1.0, + "description": "人工智能领域包含机器学习这个子领域", + } + print(f"插入边1: {node1_id} -> {node2_id}") + await storage.upsert_edge(node1_id, node2_id, edge1_data) + + # 插入边2: 机器学习 -> 深度学习 + edge2_data = { + "relationship": "包含", + "weight": 1.0, + "description": "机器学习领域包含深度学习这个子领域", + } + print(f"插入边2: {node2_id} -> {node3_id}") + await storage.upsert_edge(node2_id, node3_id, edge2_data) + + # 插入边3: 人工智能 -> 自然语言处理 + edge3_data = { + "relationship": "包含", + "weight": 1.0, + "description": "人工智能领域包含自然语言处理这个子领域", + } + print(f"插入边3: {node1_id} -> {node4_id}") + await storage.upsert_edge(node1_id, node4_id, edge3_data) + + # 插入边4: 人工智能 -> 计算机视觉 + edge4_data = { + "relationship": "包含", + "weight": 1.0, + "description": "人工智能领域包含计算机视觉这个子领域", + } + print(f"插入边4: {node1_id} -> {node5_id}") + await storage.upsert_edge(node1_id, node5_id, edge4_data) + + # 插入边5: 深度学习 -> 自然语言处理 + edge5_data = { + "relationship": "应用于", + "weight": 0.8, + "description": "深度学习技术应用于自然语言处理领域", + } + print(f"插入边5: {node3_id} -> {node4_id}") + await storage.upsert_edge(node3_id, node4_id, edge5_data) + + # 插入边6: 深度学习 -> 计算机视觉 + edge6_data = { + "relationship": "应用于", + "weight": 0.8, + "description": "深度学习技术应用于计算机视觉领域", + } + print(f"插入边6: {node3_id} -> {node5_id}") + await storage.upsert_edge(node3_id, node5_id, edge6_data) + + # 2. 测试 get_nodes_batch - 批量获取多个节点的属性 + print("== 测试 get_nodes_batch") + node_ids = [node1_id, node2_id, node3_id] + nodes_dict = await storage.get_nodes_batch(node_ids) + print(f"批量获取节点属性结果: {nodes_dict.keys()}") + assert len(nodes_dict) == 3, f"应返回3个节点,实际返回 {len(nodes_dict)} 个" + assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中" + assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中" + assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中" + assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配" + assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配" + assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配" + + # 3. 测试 node_degrees_batch - 批量获取多个节点的度数 + print("== 测试 node_degrees_batch") + node_degrees = await storage.node_degrees_batch(node_ids) + print(f"批量获取节点度数结果: {node_degrees}") + assert len(node_degrees) == 3, f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个" + assert node1_id in node_degrees, f"{node1_id} 应在返回结果中" + assert node2_id in node_degrees, f"{node2_id} 应在返回结果中" + assert node3_id in node_degrees, f"{node3_id} 应在返回结果中" + assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}" + assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}" + assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}" + + # 4. 测试 edge_degrees_batch - 批量获取多个边的度数 + print("== 测试 edge_degrees_batch") + edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)] + edge_degrees = await storage.edge_degrees_batch(edges) + print(f"批量获取边度数结果: {edge_degrees}") + assert len(edge_degrees) == 3, f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条" + assert (node1_id, node2_id) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中" + assert (node2_id, node3_id) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中" + assert (node3_id, node4_id) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中" + # 验证边的度数是否正确(源节点度数 + 目标节点度数) + assert edge_degrees[(node1_id, node2_id)] == 5, f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}" + assert edge_degrees[(node2_id, node3_id)] == 5, f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}" + assert edge_degrees[(node3_id, node4_id)] == 5, f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}" + + # 5. 测试 get_edges_batch - 批量获取多个边的属性 + print("== 测试 get_edges_batch") + # 将元组列表转换为Neo4j风格的字典列表 + edge_dicts = [{"src": src, "tgt": tgt} for src, tgt in edges] + edges_dict = await storage.get_edges_batch(edge_dicts) + print(f"批量获取边属性结果: {edges_dict.keys()}") + assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条" + assert (node1_id, node2_id) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中" + assert (node2_id, node3_id) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中" + assert (node3_id, node4_id) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中" + assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"边 {node1_id} -> {node2_id} 关系不匹配" + assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"边 {node2_id} -> {node3_id} 关系不匹配" + assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"边 {node3_id} -> {node4_id} 关系不匹配" + + # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边 + print("== 测试 get_nodes_edges_batch") + nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id]) + print(f"批量获取节点边结果: {nodes_edges.keys()}") + assert len(nodes_edges) == 2, f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个" + assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中" + assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中" + assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条" + assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条" + + # 7. 清理数据 + print("== 测试 drop") + result = await storage.drop() + print(f"清理结果: {result}") + assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}" + + print("\n批量操作测试完成") + return True + + except Exception as e: + ASCIIColors.red(f"测试过程中发生错误: {str(e)}") + return False + + async def main(): """主函数""" # 显示程序标题 @@ -411,21 +612,28 @@ async def main(): ASCIIColors.yellow("\n请选择测试类型:") ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)") ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)") - ASCIIColors.white("3. 全部测试") + ASCIIColors.white("3. 批量操作测试 (批量获取节点、边属性和度数等)") + ASCIIColors.white("4. 全部测试") - choice = input("\n请输入选项 (1/2/3): ") + choice = input("\n请输入选项 (1/2/3/4): ") if choice == "1": await test_graph_basic(storage) elif choice == "2": await test_graph_advanced(storage) elif choice == "3": + await test_graph_batch_operations(storage) + elif choice == "4": ASCIIColors.cyan("\n=== 开始基本测试 ===") basic_result = await test_graph_basic(storage) if basic_result: ASCIIColors.cyan("\n=== 开始高级测试 ===") - await test_graph_advanced(storage) + advanced_result = await test_graph_advanced(storage) + + if advanced_result: + ASCIIColors.cyan("\n=== 开始批量操作测试 ===") + await test_graph_batch_operations(storage) else: ASCIIColors.red("无效的选项")