Fix linting

This commit is contained in:
yangdx
2025-04-04 03:41:05 +08:00
parent 99cce237df
commit 394a6063ba

View File

@@ -27,14 +27,16 @@ from lightrag.kg import (
STORAGE_IMPLEMENTATIONS,
STORAGE_ENV_REQUIREMENTS,
STORAGES,
verify_storage_implementation
verify_storage_implementation,
)
from lightrag.kg.shared_storage import initialize_share_data
# 模拟的嵌入函数,返回随机向量
async def mock_embedding_func(texts):
return np.random.rand(len(texts), 10) # 返回10维随机向量
def check_env_file():
"""
检查.env文件是否存在如果不存在则发出警告
@@ -52,6 +54,7 @@ def check_env_file():
return False
return True
async def initialize_graph_storage():
"""
根据环境变量初始化相应的图存储实例
@@ -65,7 +68,9 @@ async def initialize_graph_storage():
verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
except ValueError as e:
ASCIIColors.red(f"错误: {str(e)}")
ASCIIColors.yellow(f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}")
ASCIIColors.yellow(
f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
)
return None
# 检查所需的环境变量
@@ -73,7 +78,9 @@ async def initialize_graph_storage():
missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
if missing_env_vars:
ASCIIColors.red(f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}")
ASCIIColors.red(
f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}"
)
return None
# 动态导入相应的模块
@@ -95,7 +102,7 @@ async def initialize_graph_storage():
"vector_db_storage_cls_kwargs": {
"cosine_better_than_threshold": 0.5 # 余弦相似度阈值
},
"working_dir": os.environ.get("WORKING_DIR", "./rag_storage") # 工作目录
"working_dir": os.environ.get("WORKING_DIR", "./rag_storage"), # 工作目录
}
# 如果使用 NetworkXStorage需要先初始化 shared_storage
@@ -106,7 +113,7 @@ async def initialize_graph_storage():
storage = storage_class(
namespace="test_graph",
global_config=global_config,
embedding_func=mock_embedding_func
embedding_func=mock_embedding_func,
)
# 初始化连接
@@ -116,6 +123,7 @@ async def initialize_graph_storage():
ASCIIColors.red(f"错误: 初始化 {graph_storage_type} 失败: {str(e)}")
return None
async def test_graph_basic(storage):
"""
测试图数据库的基本操作:
@@ -135,7 +143,7 @@ async def test_graph_basic(storage):
"entity_id": node1_id,
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"keywords": "AI,机器学习,深度学习",
"entity_type": "技术领域"
"entity_type": "技术领域",
}
print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
@@ -146,7 +154,7 @@ async def test_graph_basic(storage):
"entity_id": node2_id,
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
"keywords": "监督学习,无监督学习,强化学习",
"entity_type": "技术领域"
"entity_type": "技术领域",
}
print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
@@ -155,7 +163,7 @@ async def test_graph_basic(storage):
edge_data = {
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含机器学习这个子领域"
"description": "人工智能领域包含机器学习这个子领域",
}
print(f"插入边: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge_data)
@@ -169,9 +177,15 @@ async def test_graph_basic(storage):
print(f"节点类型: {node1_props.get('entity_type', '无类型')}")
print(f"节点关键词: {node1_props.get('keywords', '无关键词')}")
# 验证返回的属性是否正确
assert node1_props.get('entity_id') == node1_id, f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
assert node1_props.get('description') == node1_data['description'], "节点描述不匹配"
assert node1_props.get('entity_type') == node1_data['entity_type'], "节点类型不匹配"
assert (
node1_props.get("entity_id") == node1_id
), f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
assert (
node1_props.get("description") == node1_data["description"]
), "节点描述不匹配"
assert (
node1_props.get("entity_type") == node1_data["entity_type"]
), "节点类型不匹配"
else:
print(f"读取节点属性失败: {node1_id}")
assert False, f"未能读取节点属性: {node1_id}"
@@ -185,9 +199,13 @@ async def test_graph_basic(storage):
print(f"边描述: {edge_props.get('description', '无描述')}")
print(f"边权重: {edge_props.get('weight', '无权重')}")
# 验证返回的属性是否正确
assert edge_props.get('relationship') == edge_data['relationship'], "边关系不匹配"
assert edge_props.get('description') == edge_data['description'], "边描述不匹配"
assert edge_props.get('weight') == edge_data['weight'], "权重不匹配"
assert (
edge_props.get("relationship") == edge_data["relationship"]
), "关系不匹配"
assert (
edge_props.get("description") == edge_data["description"]
), "边描述不匹配"
assert edge_props.get("weight") == edge_data["weight"], "边权重不匹配"
else:
print(f"读取边属性失败: {node1_id} -> {node2_id}")
assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
@@ -199,6 +217,7 @@ async def test_graph_basic(storage):
ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
return False
async def test_graph_advanced(storage):
"""
测试图数据库的高级操作:
@@ -224,7 +243,7 @@ async def test_graph_advanced(storage):
"entity_id": node1_id,
"description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
"keywords": "AI,机器学习,深度学习",
"entity_type": "技术领域"
"entity_type": "技术领域",
}
print(f"插入节点1: {node1_id}")
await storage.upsert_node(node1_id, node1_data)
@@ -235,7 +254,7 @@ async def test_graph_advanced(storage):
"entity_id": node2_id,
"description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
"keywords": "监督学习,无监督学习,强化学习",
"entity_type": "技术领域"
"entity_type": "技术领域",
}
print(f"插入节点2: {node2_id}")
await storage.upsert_node(node2_id, node2_data)
@@ -246,7 +265,7 @@ async def test_graph_advanced(storage):
"entity_id": node3_id,
"description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
"keywords": "神经网络,CNN,RNN",
"entity_type": "技术领域"
"entity_type": "技术领域",
}
print(f"插入节点3: {node3_id}")
await storage.upsert_node(node3_id, node3_data)
@@ -255,7 +274,7 @@ async def test_graph_advanced(storage):
edge1_data = {
"relationship": "包含",
"weight": 1.0,
"description": "人工智能领域包含机器学习这个子领域"
"description": "人工智能领域包含机器学习这个子领域",
}
print(f"插入边1: {node1_id} -> {node2_id}")
await storage.upsert_edge(node1_id, node2_id, edge1_data)
@@ -264,7 +283,7 @@ async def test_graph_advanced(storage):
edge2_data = {
"relationship": "包含",
"weight": 1.0,
"description": "机器学习领域包含深度学习这个子领域"
"description": "机器学习领域包含深度学习这个子领域",
}
print(f"插入边2: {node2_id} -> {node3_id}")
await storage.upsert_edge(node2_id, node3_id, edge2_data)
@@ -279,13 +298,17 @@ async def test_graph_advanced(storage):
print(f"== 测试 edge_degree: {node1_id} -> {node2_id}")
edge_degree = await storage.edge_degree(node1_id, node2_id)
print(f"{node1_id} -> {node2_id} 的度数: {edge_degree}")
assert edge_degree == 3, f"{node1_id} -> {node2_id} 的度数应为2实际为 {edge_degree}"
assert (
edge_degree == 3
), f"{node1_id} -> {node2_id} 的度数应为2实际为 {edge_degree}"
# 4. 测试 get_node_edges - 获取节点的所有边
print(f"== 测试 get_node_edges: {node2_id}")
node2_edges = await storage.get_node_edges(node2_id)
print(f"节点 {node2_id} 的所有边: {node2_edges}")
assert len(node2_edges) == 2, f"节点 {node2_id} 应有2条边实际有 {len(node2_edges)}"
assert (
len(node2_edges) == 2
), f"节点 {node2_id} 应有2条边实际有 {len(node2_edges)}"
# 5. 测试 get_all_labels - 获取所有标签
print("== 测试 get_all_labels")
@@ -337,7 +360,9 @@ async def test_graph_advanced(storage):
print("== 测试 drop")
result = await storage.drop()
print(f"清理结果: {result}")
assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
assert (
result["status"] == "success"
), f"清理应成功,实际状态为 {result['status']}"
# 验证清理结果
all_labels = await storage.get_all_labels()
@@ -351,6 +376,7 @@ async def test_graph_advanced(storage):
ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
return False
async def main():
"""主函数"""
# 显示程序标题
@@ -370,7 +396,9 @@ async def main():
# 获取图存储类型
graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
ASCIIColors.magenta(f"\n当前配置的图存储类型: {graph_storage_type}")
ASCIIColors.white(f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}")
ASCIIColors.white(
f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
)
# 初始化存储实例
storage = await initialize_graph_storage()
@@ -407,5 +435,6 @@ async def main():
await storage.finalize()
ASCIIColors.green("\n存储连接已关闭")
if __name__ == "__main__":
asyncio.run(main())