From a3ca134e97e0ac625726e6e96396dde0c84144ed Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 17 Apr 2025 22:58:36 +0800 Subject: [PATCH 1/4] Fix special chars problem for Postgres --- lightrag/kg/postgres_impl.py | 4 ++-- lightrag/operate.py | 4 ++-- lightrag/utils.py | 7 ++++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 97d3316c..bcee7fa2 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1049,10 +1049,10 @@ class PGGraphStorage(BaseGraphStorage): Returns: Normalized node ID suitable for Cypher queries """ - # Remove quotes - normalized_id = node_id.strip('"') # Escape backslashes + normalized_id = node_id normalized_id = normalized_id.replace("\\", "\\\\") + normalized_id = normalized_id.replace('"', '\\"') return normalized_id async def initialize(self): diff --git a/lightrag/operate.py b/lightrag/operate.py index 84e1364e..f653d479 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -157,8 +157,8 @@ async def _handle_single_entity_extraction( return None # Clean and validate entity name - entity_name = clean_str(record_attributes[1]).strip('"') - if not entity_name.strip(): + entity_name = clean_str(record_attributes[1]).strip() + if not entity_name: logger.warning( f"Entity extraction error: empty entity name in: {record_attributes}" ) diff --git a/lightrag/utils.py b/lightrag/utils.py index 37400069..913c39f3 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1385,7 +1385,12 @@ def normalize_extracted_info(name: str, is_entity=False) -> str: name = re.sub(r"(?<=[a-zA-Z0-9])\s+(?=[\u4e00-\u9fa5])", "", name) # Remove English quotation marks from the beginning and end - name = name.strip('"').strip("'") + if ( + len(name) >= 2 + and name.startswith('"') + and name.endswith('"') + ): + name = name[1:-1] if is_entity: # remove Chinese quotes From 09cca6dbe633729e70aa8a09efce847492de595f Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 17 Apr 2025 22:58:49 +0800 Subject: [PATCH 2/4] Update graph db unit test --- tests/test_graph_storage.py | 172 +++++++++++++++++++++++++++++------- 1 file changed, 140 insertions(+), 32 deletions(-) diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index f07bbbf6..7dc1ea86 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -427,19 +427,6 @@ async def test_graph_advanced(storage): assert node2_props is None, f"节点 {node2_id} 应已被删除" assert node3_props is None, f"节点 {node3_id} 应已被删除" - # 10. 测试 drop - 清理数据 - print("== 测试 drop") - result = await storage.drop() - print(f"清理结果: {result}") - assert ( - result["status"] == "success" - ), f"清理应成功,实际状态为 {result['status']}" - - # 验证清理结果 - all_labels = await storage.get_all_labels() - print(f"清理后的所有标签: {all_labels}") - assert len(all_labels) == 0, f"清理后应没有标签,实际有 {len(all_labels)}" - print("\n高级测试完成") return True @@ -773,14 +760,6 @@ async def test_graph_batch_operations(storage): print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)") - # 7. 清理数据 - print("== 测试 drop") - result = await storage.drop() - print(f"清理结果: {result}") - assert ( - result["status"] == "success" - ), f"清理应成功,实际状态为 {result['status']}" - print("\n批量操作测试完成") return True @@ -789,6 +768,136 @@ async def test_graph_batch_operations(storage): return False +async def test_graph_special_characters(storage): + """ + 测试图数据库对特殊字符的处理: + 1. 测试节点名称和描述中包含单引号、双引号和反斜杠 + 2. 测试边的描述中包含单引号、双引号和反斜杠 + 3. 验证特殊字符是否被正确保存和检索 + """ + try: + # 清理之前的测试数据 + print("清理之前的测试数据...\n") + await storage.drop() + + # 1. 测试节点名称中的特殊字符 + node1_id = "包含'单引号'的节点" + node1_data = { + "entity_id": node1_id, + "description": "这个描述包含'单引号'、\"双引号\"和\\反斜杠", + "keywords": "特殊字符,引号,转义", + "entity_type": "测试节点", + } + print(f"插入包含特殊字符的节点1: {node1_id}") + await storage.upsert_node(node1_id, node1_data) + + # 2. 测试节点名称中的双引号 + node2_id = "包含\"双引号\"的节点" + node2_data = { + "entity_id": node2_id, + "description": "这个描述同时包含'单引号'和\"双引号\"以及\\反斜杠\\路径", + "keywords": "特殊字符,引号,JSON", + "entity_type": "测试节点", + } + print(f"插入包含特殊字符的节点2: {node2_id}") + await storage.upsert_node(node2_id, node2_data) + + # 3. 测试节点名称中的反斜杠 + node3_id = "包含\\反斜杠\\的节点" + node3_data = { + "entity_id": node3_id, + "description": "这个描述包含Windows路径C:\\Program Files\\和转义字符\\n\\t", + "keywords": "反斜杠,路径,转义", + "entity_type": "测试节点", + } + print(f"插入包含特殊字符的节点3: {node3_id}") + await storage.upsert_node(node3_id, node3_data) + + # 4. 测试边描述中的特殊字符 + edge1_data = { + "relationship": "特殊'关系'", + "weight": 1.0, + "description": "这个边描述包含'单引号'、\"双引号\"和\\反斜杠", + } + print(f"插入包含特殊字符的边: {node1_id} -> {node2_id}") + await storage.upsert_edge(node1_id, node2_id, edge1_data) + + # 5. 测试边描述中的更复杂特殊字符组合 + edge2_data = { + "relationship": "复杂\"关系\"\\类型", + "weight": 0.8, + "description": "包含SQL注入尝试: SELECT * FROM users WHERE name='admin'--", + } + print(f"插入包含复杂特殊字符的边: {node2_id} -> {node3_id}") + await storage.upsert_edge(node2_id, node3_id, edge2_data) + + # 6. 验证节点特殊字符是否正确保存 + print("\n== 验证节点特殊字符") + for node_id, original_data in [ + (node1_id, node1_data), + (node2_id, node2_data), + (node3_id, node3_data), + ]: + node_props = await storage.get_node(node_id) + if node_props: + print(f"成功读取节点: {node_id}") + print(f"节点描述: {node_props.get('description', '无描述')}") + + # 验证节点ID是否正确保存 + assert node_props.get("entity_id") == node_id, f"节点ID不匹配: 期望 {node_id}, 实际 {node_props.get('entity_id')}" + + # 验证描述是否正确保存 + assert node_props.get("description") == original_data["description"], f"节点描述不匹配: 期望 {original_data['description']}, 实际 {node_props.get('description')}" + + print(f"节点 {node_id} 特殊字符验证成功") + else: + print(f"读取节点属性失败: {node_id}") + assert False, f"未能读取节点属性: {node_id}" + + # 7. 验证边特殊字符是否正确保存 + print("\n== 验证边特殊字符") + edge1_props = await storage.get_edge(node1_id, node2_id) + if edge1_props: + print(f"成功读取边: {node1_id} -> {node2_id}") + print(f"边关系: {edge1_props.get('relationship', '无关系')}") + print(f"边描述: {edge1_props.get('description', '无描述')}") + + # 验证边关系是否正确保存 + assert edge1_props.get("relationship") == edge1_data["relationship"], f"边关系不匹配: 期望 {edge1_data['relationship']}, 实际 {edge1_props.get('relationship')}" + + # 验证边描述是否正确保存 + assert edge1_props.get("description") == edge1_data["description"], f"边描述不匹配: 期望 {edge1_data['description']}, 实际 {edge1_props.get('description')}" + + print(f"边 {node1_id} -> {node2_id} 特殊字符验证成功") + else: + print(f"读取边属性失败: {node1_id} -> {node2_id}") + assert False, f"未能读取边属性: {node1_id} -> {node2_id}" + + edge2_props = await storage.get_edge(node2_id, node3_id) + if edge2_props: + print(f"成功读取边: {node2_id} -> {node3_id}") + print(f"边关系: {edge2_props.get('relationship', '无关系')}") + print(f"边描述: {edge2_props.get('description', '无描述')}") + + # 验证边关系是否正确保存 + assert edge2_props.get("relationship") == edge2_data["relationship"], f"边关系不匹配: 期望 {edge2_data['relationship']}, 实际 {edge2_props.get('relationship')}" + + # 验证边描述是否正确保存 + assert edge2_props.get("description") == edge2_data["description"], f"边描述不匹配: 期望 {edge2_data['description']}, 实际 {edge2_props.get('description')}" + + print(f"边 {node2_id} -> {node3_id} 特殊字符验证成功") + else: + print(f"读取边属性失败: {node2_id} -> {node3_id}") + assert False, f"未能读取边属性: {node2_id} -> {node3_id}" + + print("\n特殊字符测试完成,数据已保留在数据库中") + return True + + except Exception as e: + ASCIIColors.red(f"测试过程中发生错误: {str(e)}") + return False + + async def test_graph_undirected_property(storage): """ 专门测试图存储的无向图特性: @@ -973,14 +1082,6 @@ async def test_graph_undirected_property(storage): print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)") - # 7. 清理数据 - print("== 测试 drop") - result = await storage.drop() - print(f"清理结果: {result}") - assert ( - result["status"] == "success" - ), f"清理应成功,实际状态为 {result['status']}" - print("\n无向图特性测试完成") return True @@ -1025,9 +1126,10 @@ async def main(): ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)") ASCIIColors.white("3. 批量操作测试 (批量获取节点、边属性和度数等)") ASCIIColors.white("4. 无向图特性测试 (验证存储的无向图特性)") - ASCIIColors.white("5. 全部测试") + ASCIIColors.white("5. 特殊字符测试 (验证单引号、双引号和反斜杠等特殊字符)") + ASCIIColors.white("6. 全部测试") - choice = input("\n请输入选项 (1/2/3/4/5): ") + choice = input("\n请输入选项 (1/2/3/4/5/6): ") if choice == "1": await test_graph_basic(storage) @@ -1038,6 +1140,8 @@ async def main(): elif choice == "4": await test_graph_undirected_property(storage) elif choice == "5": + await test_graph_special_characters(storage) + elif choice == "6": ASCIIColors.cyan("\n=== 开始基本测试 ===") basic_result = await test_graph_basic(storage) @@ -1051,7 +1155,11 @@ async def main(): if batch_result: ASCIIColors.cyan("\n=== 开始无向图特性测试 ===") - await test_graph_undirected_property(storage) + undirected_result = await test_graph_undirected_property(storage) + + if undirected_result: + ASCIIColors.cyan("\n=== 开始特殊字符测试 ===") + await test_graph_special_characters(storage) else: ASCIIColors.red("无效的选项") From bffb9dbdb0d8d710c3e25ea38baa4744b43fa69c Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 17 Apr 2025 23:00:34 +0800 Subject: [PATCH 3/4] Fix linting --- lightrag/utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 913c39f3..dc717fb7 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1385,11 +1385,7 @@ def normalize_extracted_info(name: str, is_entity=False) -> str: name = re.sub(r"(?<=[a-zA-Z0-9])\s+(?=[\u4e00-\u9fa5])", "", name) # Remove English quotation marks from the beginning and end - if ( - len(name) >= 2 - and name.startswith('"') - and name.endswith('"') - ): + if len(name) >= 2 and name.startswith('"') and name.endswith('"'): name = name[1:-1] if is_entity: From e9dcac7caf783b12c2c4c273c6fd67662654fb9e Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 17 Apr 2025 23:09:01 +0800 Subject: [PATCH 4/4] Update graph db test --- tests/test_graph_storage.py | 78 ++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 7dc1ea86..fb78270d 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -133,10 +133,6 @@ async def test_graph_basic(storage): 4. 使用 get_edge 读取一条边 """ try: - # 清理之前的测试数据 - print("清理之前的测试数据...") - await storage.drop() - # 1. 插入第一个节点 node1_id = "人工智能" node1_data = { @@ -251,10 +247,6 @@ async def test_graph_advanced(storage): 9. 使用 drop 清理数据 """ try: - # 清理之前的测试数据 - print("清理之前的测试数据...\n") - await storage.drop() - # 1. 插入测试数据 # 插入节点1: 人工智能 node1_id = "人工智能" @@ -445,10 +437,6 @@ async def test_graph_batch_operations(storage): 5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边 """ try: - # 清理之前的测试数据 - print("清理之前的测试数据...\n") - await storage.drop() - # 1. 插入测试数据 # 插入节点1: 人工智能 node1_id = "人工智能" @@ -776,10 +764,6 @@ async def test_graph_special_characters(storage): 3. 验证特殊字符是否被正确保存和检索 """ try: - # 清理之前的测试数据 - print("清理之前的测试数据...\n") - await storage.drop() - # 1. 测试节点名称中的特殊字符 node1_id = "包含'单引号'的节点" node1_data = { @@ -792,7 +776,7 @@ async def test_graph_special_characters(storage): await storage.upsert_node(node1_id, node1_data) # 2. 测试节点名称中的双引号 - node2_id = "包含\"双引号\"的节点" + node2_id = '包含"双引号"的节点' node2_data = { "entity_id": node2_id, "description": "这个描述同时包含'单引号'和\"双引号\"以及\\反斜杠\\路径", @@ -824,7 +808,7 @@ async def test_graph_special_characters(storage): # 5. 测试边描述中的更复杂特殊字符组合 edge2_data = { - "relationship": "复杂\"关系\"\\类型", + "relationship": '复杂"关系"\\类型', "weight": 0.8, "description": "包含SQL注入尝试: SELECT * FROM users WHERE name='admin'--", } @@ -842,13 +826,17 @@ async def test_graph_special_characters(storage): if node_props: print(f"成功读取节点: {node_id}") print(f"节点描述: {node_props.get('description', '无描述')}") - + # 验证节点ID是否正确保存 - assert node_props.get("entity_id") == node_id, f"节点ID不匹配: 期望 {node_id}, 实际 {node_props.get('entity_id')}" - + assert ( + node_props.get("entity_id") == node_id + ), f"节点ID不匹配: 期望 {node_id}, 实际 {node_props.get('entity_id')}" + # 验证描述是否正确保存 - assert node_props.get("description") == original_data["description"], f"节点描述不匹配: 期望 {original_data['description']}, 实际 {node_props.get('description')}" - + assert ( + node_props.get("description") == original_data["description"] + ), f"节点描述不匹配: 期望 {original_data['description']}, 实际 {node_props.get('description')}" + print(f"节点 {node_id} 特殊字符验证成功") else: print(f"读取节点属性失败: {node_id}") @@ -861,13 +849,17 @@ async def test_graph_special_characters(storage): print(f"成功读取边: {node1_id} -> {node2_id}") print(f"边关系: {edge1_props.get('relationship', '无关系')}") print(f"边描述: {edge1_props.get('description', '无描述')}") - + # 验证边关系是否正确保存 - assert edge1_props.get("relationship") == edge1_data["relationship"], f"边关系不匹配: 期望 {edge1_data['relationship']}, 实际 {edge1_props.get('relationship')}" - + assert ( + edge1_props.get("relationship") == edge1_data["relationship"] + ), f"边关系不匹配: 期望 {edge1_data['relationship']}, 实际 {edge1_props.get('relationship')}" + # 验证边描述是否正确保存 - assert edge1_props.get("description") == edge1_data["description"], f"边描述不匹配: 期望 {edge1_data['description']}, 实际 {edge1_props.get('description')}" - + assert ( + edge1_props.get("description") == edge1_data["description"] + ), f"边描述不匹配: 期望 {edge1_data['description']}, 实际 {edge1_props.get('description')}" + print(f"边 {node1_id} -> {node2_id} 特殊字符验证成功") else: print(f"读取边属性失败: {node1_id} -> {node2_id}") @@ -878,13 +870,17 @@ async def test_graph_special_characters(storage): print(f"成功读取边: {node2_id} -> {node3_id}") print(f"边关系: {edge2_props.get('relationship', '无关系')}") print(f"边描述: {edge2_props.get('description', '无描述')}") - + # 验证边关系是否正确保存 - assert edge2_props.get("relationship") == edge2_data["relationship"], f"边关系不匹配: 期望 {edge2_data['relationship']}, 实际 {edge2_props.get('relationship')}" - + assert ( + edge2_props.get("relationship") == edge2_data["relationship"] + ), f"边关系不匹配: 期望 {edge2_data['relationship']}, 实际 {edge2_props.get('relationship')}" + # 验证边描述是否正确保存 - assert edge2_props.get("description") == edge2_data["description"], f"边描述不匹配: 期望 {edge2_data['description']}, 实际 {edge2_props.get('description')}" - + assert ( + edge2_props.get("description") == edge2_data["description"] + ), f"边描述不匹配: 期望 {edge2_data['description']}, 实际 {edge2_props.get('description')}" + print(f"边 {node2_id} -> {node3_id} 特殊字符验证成功") else: print(f"读取边属性失败: {node2_id} -> {node3_id}") @@ -907,10 +903,6 @@ async def test_graph_undirected_property(storage): 4. 验证批量操作中的无向图特性 """ try: - # 清理之前的测试数据 - print("清理之前的测试数据...\n") - await storage.drop() - # 1. 插入测试数据 # 插入节点1: 计算机科学 node1_id = "计算机科学" @@ -1131,6 +1123,12 @@ async def main(): choice = input("\n请输入选项 (1/2/3/4/5/6): ") + # 在执行测试前清理数据 + if choice in ["1", "2", "3", "4", "5", "6"]: + ASCIIColors.yellow("\n执行测试前清理数据...") + await storage.drop() + ASCIIColors.green("数据清理完成\n") + if choice == "1": await test_graph_basic(storage) elif choice == "2": @@ -1155,8 +1153,10 @@ async def main(): if batch_result: ASCIIColors.cyan("\n=== 开始无向图特性测试 ===") - undirected_result = await test_graph_undirected_property(storage) - + undirected_result = await test_graph_undirected_property( + storage + ) + if undirected_result: ASCIIColors.cyan("\n=== 开始特殊字符测试 ===") await test_graph_special_characters(storage)