Translate comment to English

This commit is contained in:
yangdx
2025-01-17 13:36:31 +08:00
parent 3138ae7599
commit 939e399dd4
2 changed files with 77 additions and 93 deletions

View File

@@ -1,12 +1,12 @@
"""
LightRAG Ollama 兼容接口测试脚本
LightRAG Ollama Compatibility Interface Test Script
这个脚本测试 LightRAG Ollama 兼容接口,包括:
1. 基本功能测试(流式和非流式响应)
2. 查询模式测试(localglobalnaivehybrid
3. 错误处理测试(包括流式和非流式场景)
This script tests the LightRAG's Ollama compatibility interface, including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid)
3. Error handling tests (including streaming and non-streaming scenarios)
所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。
All responses use the JSON Lines format, complying with the Ollama API specification.
"""
import requests
@@ -24,20 +24,10 @@ class OutputControl:
@classmethod
def set_verbose(cls, verbose: bool) -> None:
"""设置输出详细程度
Args:
verbose: True 为详细模式False 为静默模式
"""
cls._verbose = verbose
@classmethod
def is_verbose(cls) -> bool:
"""获取当前输出模式
Returns:
当前是否为详细模式
"""
return cls._verbose
@dataclass
@@ -48,9 +38,8 @@ class TestResult:
duration: float
error: Optional[str] = None
timestamp: str = ""
def __post_init__(self):
"""初始化后设置时间戳"""
if not self.timestamp:
self.timestamp = datetime.now().isoformat()
@@ -59,14 +48,13 @@ class TestStats:
def __init__(self):
self.results: List[TestResult] = []
self.start_time = datetime.now()
def add_result(self, result: TestResult):
"""添加测试结果"""
self.results.append(result)
def export_results(self, path: str = "test_results.json"):
"""导出测试结果到 JSON 文件
Args:
path: 输出文件路径
"""
@@ -81,25 +69,24 @@ class TestStats:
"total_duration": sum(r.duration for r in self.results)
}
}
with open(path, "w", encoding="utf-8") as f:
json.dump(results_data, f, ensure_ascii=False, indent=2)
print(f"\n测试结果已保存到: {path}")
def print_summary(self):
"""打印测试统计摘要"""
total = len(self.results)
passed = sum(1 for r in self.results if r.success)
failed = total - passed
duration = sum(r.duration for r in self.results)
print("\n=== 测试结果摘要 ===")
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"总用时: {duration:.2f}")
print(f"总计: {total} 个测试")
print(f"通过: {passed}")
print(f"失败: {failed}")
if failed > 0:
print("\n失败的测试:")
for result in self.results:
@@ -125,15 +112,15 @@ DEFAULT_CONFIG = {
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
"""发送 HTTP 请求,支持重试机制
Args:
url: 请求 URL
data: 请求数据
stream: 是否使用流式响应
Returns:
requests.Response 对象
requests.Response: 对象
Raises:
requests.exceptions.RequestException: 请求失败且重试次数用完
"""
@@ -141,7 +128,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
max_retries = server_config["max_retries"]
retry_delay = server_config["retry_delay"]
timeout = server_config["timeout"]
for attempt in range(max_retries):
try:
response = requests.post(
@@ -159,10 +146,10 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
def load_config() -> Dict[str, Any]:
"""加载配置文件
首先尝试从当前目录的 config.json 加载,
如果不存在则使用默认配置
Returns:
配置字典
"""
@@ -174,7 +161,7 @@ def load_config() -> Dict[str, Any]:
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
"""格式化打印 JSON 响应数据
Args:
data: 要打印的数据字典
title: 打印的标题
@@ -199,12 +186,12 @@ def create_request_data(
model: str = None
) -> Dict[str, Any]:
"""创建基本的请求数据
Args:
content: 用户消息内容
stream: 是否使用流式响应
model: 模型名称
Returns:
包含完整请求数据的字典
"""
@@ -224,7 +211,7 @@ STATS = TestStats()
def run_test(func: Callable, name: str) -> None:
"""运行测试并记录结果
Args:
func: 测试函数
name: 测试名称
@@ -246,21 +233,21 @@ def test_non_stream_chat():
CONFIG["test_cases"]["basic"]["query"],
stream=False
)
# 发送请求
response = make_request(url, data)
# 打印响应
if OutputControl.is_verbose():
print("\n=== 非流式调用响应 ===")
response_json = response.json()
# 打印响应内容
print_json_response({
"model": response_json["model"],
"message": response_json["message"]
}, "响应内容")
# # 打印性能统计
# print_json_response({
# "total_duration": response_json["total_duration"],
@@ -273,7 +260,7 @@ def test_non_stream_chat():
def test_stream_chat():
"""测试流式调用 /api/chat 接口
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
响应格式:
{
@@ -286,7 +273,7 @@ def test_stream_chat():
},
"done": false
}
最后一条消息会包含性能统计信息done 为 true。
"""
url = get_base_url()
@@ -294,10 +281,10 @@ def test_stream_chat():
CONFIG["test_cases"]["basic"]["query"],
stream=True
)
# 发送请求并获取流式响应
response = make_request(url, data, stream=True)
if OutputControl.is_verbose():
print("\n=== 流式调用响应 ===")
output_buffer = []
@@ -321,24 +308,24 @@ def test_stream_chat():
print("Error decoding JSON from response line")
finally:
response.close() # 确保关闭响应连接
# 打印一个换行
print()
def test_query_modes():
"""测试不同的查询模式前缀
支持的查询模式:
- /local: 本地检索模式,只在相关度高的文档中搜索
- /global: 全局检索模式,在所有文档中搜索
- /naive: 朴素模式,不使用任何优化策略
- /hybrid: 混合模式(默认),结合多种策略
每个模式都会返回相同格式的响应,但检索策略不同。
"""
url = get_base_url()
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
for mode in modes:
if OutputControl.is_verbose():
print(f"\n=== 测试 /{mode} 模式 ===")
@@ -346,11 +333,11 @@ def test_query_modes():
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
stream=False
)
# 发送请求
response = make_request(url, data)
response_json = response.json()
# 打印响应内容
print_json_response({
"model": response_json["model"],
@@ -359,13 +346,13 @@ def test_query_modes():
def create_error_test_data(error_type: str) -> Dict[str, Any]:
"""创建用于错误测试的请求数据
Args:
error_type: 错误类型,支持:
- empty_messages: 空消息列表
- invalid_role: 无效的角色字段
- missing_content: 缺少内容字段
Returns:
包含错误数据的请求字典
"""
@@ -399,19 +386,19 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
def test_stream_error_handling():
"""测试流式响应的错误处理
测试场景:
1. 空消息列表
2. 消息格式错误(缺少必需字段)
错误响应会立即返回,不会建立流式连接。
状态码应该是 4xx并返回详细的错误信息。
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== 测试流式响应错误处理 ===")
# 测试空消息列表
if OutputControl.is_verbose():
print("\n--- 测试空消息列表(流式)---")
@@ -421,7 +408,7 @@ def test_stream_error_handling():
if response.status_code != 200:
print_json_response(response.json(), "错误信息")
response.close()
# 测试无效角色字段
if OutputControl.is_verbose():
print("\n--- 测试无效角色字段(流式)---")
@@ -444,23 +431,23 @@ def test_stream_error_handling():
def test_error_handling():
"""测试非流式响应的错误处理
测试场景:
1. 空消息列表
2. 消息格式错误(缺少必需字段)
错误响应格式:
{
"detail": "错误描述"
}
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== 测试错误处理 ===")
# 测试空消息列表
if OutputControl.is_verbose():
print("\n--- 测试空消息列表 ---")
@@ -469,7 +456,7 @@ def test_error_handling():
response = make_request(url, data)
print(f"状态码: {response.status_code}")
print_json_response(response.json(), "错误信息")
# 测试无效角色字段
if OutputControl.is_verbose():
print("\n--- 测试无效角色字段 ---")
@@ -490,7 +477,7 @@ def test_error_handling():
def get_test_cases() -> Dict[str, Callable]:
"""获取所有可用的测试用例
Returns:
测试名称到测试函数的映射字典
"""
@@ -564,21 +551,21 @@ def parse_args() -> argparse.Namespace:
if __name__ == "__main__":
args = parse_args()
# 设置输出模式
OutputControl.set_verbose(not args.quiet)
# 如果指定了查询内容,更新配置
if args.ask:
CONFIG["test_cases"]["basic"]["query"] = args.ask
# 如果指定了创建配置文件
if args.init_config:
create_default_config()
exit(0)
test_cases = get_test_cases()
try:
if "all" in args.tests:
# 运行所有测试
@@ -586,11 +573,11 @@ if __name__ == "__main__":
print("\n【基本功能测试】")
run_test(test_non_stream_chat, "非流式调用测试")
run_test(test_stream_chat, "流式调用测试")
if OutputControl.is_verbose():
print("\n【查询模式测试】")
run_test(test_query_modes, "查询模式测试")
if OutputControl.is_verbose():
print("\n【错误处理测试】")
run_test(test_error_handling, "错误处理测试")