Translate unit test comment and promts to English

This commit is contained in:
yangdx
2025-01-17 14:07:17 +08:00
parent 939e399dd4
commit 48f70ff8b4

View File

@@ -19,7 +19,7 @@ from datetime import datetime
from pathlib import Path
class OutputControl:
"""输出控制类,管理测试输出的详细程度"""
"""Output control class, manages the verbosity of test output"""
_verbose: bool = False
@classmethod
@@ -32,7 +32,7 @@ class OutputControl:
@dataclass
class TestResult:
"""测试结果数据类"""
"""Test result data class"""
name: str
success: bool
duration: float
@@ -44,7 +44,7 @@ class TestResult:
self.timestamp = datetime.now().isoformat()
class TestStats:
"""测试统计信息"""
"""Test statistics"""
def __init__(self):
self.results: List[TestResult] = []
self.start_time = datetime.now()
@@ -53,10 +53,9 @@ class TestStats:
self.results.append(result)
def export_results(self, path: str = "test_results.json"):
"""导出测试结果到 JSON 文件
"""Export test results to a JSON file
Args:
path: 输出文件路径
path: Output file path
"""
results_data = {
"start_time": self.start_time.isoformat(),
@@ -72,7 +71,7 @@ class TestStats:
with open(path, "w", encoding="utf-8") as f:
json.dump(results_data, f, ensure_ascii=False, indent=2)
print(f"\n测试结果已保存到: {path}")
print(f"\nTest results saved to: {path}")
def print_summary(self):
total = len(self.results)
@@ -80,28 +79,27 @@ class TestStats:
failed = total - passed
duration = sum(r.duration for r in self.results)
print("\n=== 测试结果摘要 ===")
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"总用时: {duration:.2f}")
print(f"总计: {total} 个测试")
print(f"通过: {passed}")
print(f"失败: {failed}")
print("\n=== Test Summary ===")
print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total duration: {duration:.2f} seconds")
print(f"Total tests: {total}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if failed > 0:
print("\n失败的测试:")
print("\nFailed tests:")
for result in self.results:
if not result.success:
print(f"- {result.name}: {result.error}")
# 默认配置
DEFAULT_CONFIG = {
"server": {
"host": "localhost",
"port": 9621,
"model": "lightrag:latest",
"timeout": 30, # 请求超时时间(秒)
"max_retries": 3, # 最大重试次数
"retry_delay": 1 # 重试间隔(秒)
"timeout": 30,
"max_retries": 3,
"retry_delay": 1
},
"test_cases": {
"basic": {
@@ -111,18 +109,16 @@ DEFAULT_CONFIG = {
}
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
"""发送 HTTP 请求,支持重试机制
"""Send an HTTP request with retry mechanism
Args:
url: 请求 URL
data: 请求数据
stream: 是否使用流式响应
url: Request URL
data: Request data
stream: Whether to use streaming response
Returns:
requests.Response: 对象
requests.Response: Response object
Raises:
requests.exceptions.RequestException: 请求失败且重试次数用完
requests.exceptions.RequestException: Request failed after all retries
"""
server_config = CONFIG["server"]
max_retries = server_config["max_retries"]
@@ -139,19 +135,18 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
)
return response
except requests.exceptions.RequestException as e:
if attempt == max_retries - 1: # 最后一次重试
if attempt == max_retries - 1: # Last retry
raise
print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
time.sleep(retry_delay)
def load_config() -> Dict[str, Any]:
"""加载配置文件
首先尝试从当前目录的 config.json 加载,
如果不存在则使用默认配置
"""Load configuration file
First try to load from config.json in the current directory,
if it doesn't exist, use the default configuration
Returns:
配置字典
Configuration dictionary
"""
config_path = Path("config.json")
if config_path.exists():
@@ -160,23 +155,22 @@ def load_config() -> Dict[str, Any]:
return DEFAULT_CONFIG
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
"""格式化打印 JSON 响应数据
"""Format and print JSON response data
Args:
data: 要打印的数据字典
title: 打印的标题
indent: JSON 缩进空格数
data: Data dictionary to print
title: Title to print
indent: Number of spaces for JSON indentation
"""
if OutputControl.is_verbose():
if title:
print(f"\n=== {title} ===")
print(json.dumps(data, ensure_ascii=False, indent=indent))
# 全局配置
# Global configuration
CONFIG = load_config()
def get_base_url() -> str:
"""返回基础 URL"""
"""Return the base URL"""
server = CONFIG["server"]
return f"http://{server['host']}:{server['port']}/api/chat"
@@ -185,15 +179,13 @@ def create_request_data(
stream: bool = False,
model: str = None
) -> Dict[str, Any]:
"""创建基本的请求数据
"""Create basic request data
Args:
content: 用户消息内容
stream: 是否使用流式响应
model: 模型名称
content: User message content
stream: Whether to use streaming response
model: Model name
Returns:
包含完整请求数据的字典
Dictionary containing complete request data
"""
return {
"model": model or CONFIG["server"]["model"],
@@ -206,15 +198,14 @@ def create_request_data(
"stream": stream
}
# 全局测试统计
# Global test statistics
STATS = TestStats()
def run_test(func: Callable, name: str) -> None:
"""运行测试并记录结果
"""Run a test and record the results
Args:
func: 测试函数
name: 测试名称
func: Test function
name: Test name
"""
start_time = time.time()
try:
@@ -227,54 +218,43 @@ def run_test(func: Callable, name: str) -> None:
raise
def test_non_stream_chat():
"""测试非流式调用 /api/chat 接口"""
"""Test non-streaming call to /api/chat endpoint"""
url = get_base_url()
data = create_request_data(
CONFIG["test_cases"]["basic"]["query"],
stream=False
)
# 发送请求
# Send request
response = make_request(url, data)
# 打印响应
# Print response
if OutputControl.is_verbose():
print("\n=== 非流式调用响应 ===")
print("\n=== Non-streaming call response ===")
response_json = response.json()
# 打印响应内容
# Print response content
print_json_response({
"model": response_json["model"],
"message": response_json["message"]
}, "响应内容")
# # 打印性能统计
# print_json_response({
# "total_duration": response_json["total_duration"],
# "load_duration": response_json["load_duration"],
# "prompt_eval_count": response_json["prompt_eval_count"],
# "prompt_eval_duration": response_json["prompt_eval_duration"],
# "eval_count": response_json["eval_count"],
# "eval_duration": response_json["eval_duration"]
# }, "性能统计")
}, "Response content")
def test_stream_chat():
"""测试流式调用 /api/chat 接口
"""Test streaming call to /api/chat endpoint
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
响应格式:
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
Response format:
{
"model": "lightrag:latest",
"created_at": "2024-01-15T00:00:00Z",
"message": {
"role": "assistant",
"content": "部分响应内容",
"content": "Partial response content",
"images": null
},
"done": false
}
最后一条消息会包含性能统计信息done 为 true
The last message will contain performance statistics, with done set to true.
"""
url = get_base_url()
data = create_request_data(
@@ -282,79 +262,79 @@ def test_stream_chat():
stream=True
)
# 发送请求并获取流式响应
# Send request and get streaming response
response = make_request(url, data, stream=True)
if OutputControl.is_verbose():
print("\n=== 流式调用响应 ===")
print("\n=== Streaming call response ===")
output_buffer = []
try:
for line in response.iter_lines():
if line: # 跳过空行
if line: # Skip empty lines
try:
# 解码并解析 JSON
# Decode and parse JSON
data = json.loads(line.decode('utf-8'))
if data.get("done", True): # 如果是完成标记
if "total_duration" in data: # 最终的性能统计消息
# print_json_response(data, "性能统计")
if data.get("done", True): # If it's the completion marker
if "total_duration" in data: # Final performance statistics message
# print_json_response(data, "Performance statistics")
break
else: # 正常的内容消息
else: # Normal content message
message = data.get("message", {})
content = message.get("content", "")
if content: # 只收集非空内容
if content: # Only collect non-empty content
output_buffer.append(content)
print(content, end="", flush=True) # 实时打印内容
print(content, end="", flush=True) # Print content in real-time
except json.JSONDecodeError:
print("Error decoding JSON from response line")
finally:
response.close() # 确保关闭响应连接
response.close() # Ensure the response connection is closed
# 打印一个换行
# Print a newline
print()
def test_query_modes():
"""测试不同的查询模式前缀
"""Test different query mode prefixes
支持的查询模式:
- /local: 本地检索模式,只在相关度高的文档中搜索
- /global: 全局检索模式,在所有文档中搜索
- /naive: 朴素模式,不使用任何优化策略
- /hybrid: 混合模式(默认),结合多种策略
Supported query modes:
- /local: Local retrieval mode, searches only in highly relevant documents
- /global: Global retrieval mode, searches across all documents
- /naive: Naive mode, does not use any optimization strategies
- /hybrid: Hybrid mode (default), combines multiple strategies
- /mix: Mix mode
每个模式都会返回相同格式的响应,但检索策略不同。
Each mode will return responses in the same format, but with different retrieval strategies.
"""
url = get_base_url()
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
modes = ["local", "global", "naive", "hybrid", "mix"]
for mode in modes:
if OutputControl.is_verbose():
print(f"\n=== 测试 /{mode} 模式 ===")
print(f"\n=== Testing /{mode} mode ===")
data = create_request_data(
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
stream=False
)
# 发送请求
# Send request
response = make_request(url, data)
response_json = response.json()
# 打印响应内容
# Print response content
print_json_response({
"model": response_json["model"],
"message": response_json["message"]
})
def create_error_test_data(error_type: str) -> Dict[str, Any]:
"""创建用于错误测试的请求数据
"""Create request data for error testing
Args:
error_type: 错误类型,支持:
- empty_messages: 空消息列表
- invalid_role: 无效的角色字段
- missing_content: 缺少内容字段
error_type: Error type, supported:
- empty_messages: Empty message list
- invalid_role: Invalid role field
- missing_content: Missing content field
Returns:
包含错误数据的请求字典
Request dictionary containing error data
"""
error_data = {
"empty_messages": {
@@ -367,7 +347,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
"messages": [
{
"invalid_role": "user",
"content": "测试消息"
"content": "Test message"
}
],
"stream": True
@@ -385,101 +365,100 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
return error_data.get(error_type, error_data["empty_messages"])
def test_stream_error_handling():
"""测试流式响应的错误处理
"""Test error handling for streaming responses
测试场景:
1. 空消息列表
2. 消息格式错误(缺少必需字段)
Test scenarios:
1. Empty message list
2. Message format error (missing required fields)
错误响应会立即返回,不会建立流式连接。
状态码应该是 4xx并返回详细的错误信息。
Error responses should be returned immediately without establishing a streaming connection.
The status code should be 4xx, and detailed error information should be returned.
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== 测试流式响应错误处理 ===")
print("\n=== Testing streaming response error handling ===")
# 测试空消息列表
# Test empty message list
if OutputControl.is_verbose():
print("\n--- 测试空消息列表(流式)---")
print("\n--- Testing empty message list (streaming) ---")
data = create_error_test_data("empty_messages")
response = make_request(url, data, stream=True)
print(f"状态码: {response.status_code}")
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "错误信息")
print_json_response(response.json(), "Error message")
response.close()
# 测试无效角色字段
# Test invalid role field
if OutputControl.is_verbose():
print("\n--- 测试无效角色字段(流式)---")
print("\n--- Testing invalid role field (streaming) ---")
data = create_error_test_data("invalid_role")
response = make_request(url, data, stream=True)
print(f"状态码: {response.status_code}")
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "错误信息")
print_json_response(response.json(), "Error message")
response.close()
# 测试缺少内容字段
# Test missing content field
if OutputControl.is_verbose():
print("\n--- 测试缺少内容字段(流式)---")
print("\n--- Testing missing content field (streaming) ---")
data = create_error_test_data("missing_content")
response = make_request(url, data, stream=True)
print(f"状态码: {response.status_code}")
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "错误信息")
print_json_response(response.json(), "Error message")
response.close()
def test_error_handling():
"""测试非流式响应的错误处理
"""Test error handling for non-streaming responses
测试场景:
1. 空消息列表
2. 消息格式错误(缺少必需字段)
Test scenarios:
1. Empty message list
2. Message format error (missing required fields)
错误响应格式:
Error response format:
{
"detail": "错误描述"
"detail": "Error description"
}
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
All errors should return appropriate HTTP status codes and clear error messages.
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== 测试错误处理 ===")
print("\n=== Testing error handling ===")
# 测试空消息列表
# Test empty message list
if OutputControl.is_verbose():
print("\n--- 测试空消息列表 ---")
print("\n--- Testing empty message list ---")
data = create_error_test_data("empty_messages")
data["stream"] = False # 修改为非流式模式
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"状态码: {response.status_code}")
print_json_response(response.json(), "错误信息")
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
# 测试无效角色字段
# Test invalid role field
if OutputControl.is_verbose():
print("\n--- 测试无效角色字段 ---")
print("\n--- Testing invalid role field ---")
data = create_error_test_data("invalid_role")
data["stream"] = False # 修改为非流式模式
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"状态码: {response.status_code}")
print_json_response(response.json(), "错误信息")
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
# 测试缺少内容字段
# Test missing content field
if OutputControl.is_verbose():
print("\n--- 测试缺少内容字段 ---")
print("\n--- Testing missing content field ---")
data = create_error_test_data("missing_content")
data["stream"] = False # 修改为非流式模式
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"状态码: {response.status_code}")
print_json_response(response.json(), "错误信息")
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
def get_test_cases() -> Dict[str, Callable]:
"""获取所有可用的测试用例
"""Get all available test cases
Returns:
测试名称到测试函数的映射字典
A dictionary mapping test names to test functions
"""
return {
"non_stream": test_non_stream_chat,
@@ -490,30 +469,30 @@ def get_test_cases() -> Dict[str, Callable]:
}
def create_default_config():
"""创建默认配置文件"""
"""Create a default configuration file"""
config_path = Path("config.json")
if not config_path.exists():
with open(config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
print(f"已创建默认配置文件: {config_path}")
print(f"Default configuration file created: {config_path}")
def parse_args() -> argparse.Namespace:
"""解析命令行参数"""
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="LightRAG Ollama 兼容接口测试",
description="LightRAG Ollama Compatibility Interface Testing",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
配置文件 (config.json):
Configuration file (config.json):
{
"server": {
"host": "localhost", # 服务器地址
"port": 9621, # 服务器端口
"model": "lightrag:latest" # 默认模型名称
"host": "localhost", # Server address
"port": 9621, # Server port
"model": "lightrag:latest" # Default model name
},
"test_cases": {
"basic": {
"query": "测试查询", # 基本查询文本
"stream_query": "流式查询" # 流式查询文本
"query": "Test query", # Basic query text
"stream_query": "Stream query" # Stream query text
}
}
}
@@ -522,44 +501,44 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"-q", "--quiet",
action="store_true",
help="静默模式,只显示测试结果摘要"
help="Silent mode, only display test result summary"
)
parser.add_argument(
"-a", "--ask",
type=str,
help="指定查询内容,会覆盖配置文件中的查询设置"
help="Specify query content, which will override the query settings in the configuration file"
)
parser.add_argument(
"--init-config",
action="store_true",
help="创建默认配置文件"
help="Create default configuration file"
)
parser.add_argument(
"--output",
type=str,
default="",
help="测试结果输出文件路径,默认不输出到文件"
help="Test result output file path, default is not to output to a file"
)
parser.add_argument(
"--tests",
nargs="+",
choices=list(get_test_cases().keys()) + ["all"],
default=["all"],
help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# 设置输出模式
# Set output mode
OutputControl.set_verbose(not args.quiet)
# 如果指定了查询内容,更新配置
# If query content is specified, update the configuration
if args.ask:
CONFIG["test_cases"]["basic"]["query"] = args.ask
# 如果指定了创建配置文件
# If specified to create a configuration file
if args.init_config:
create_default_config()
exit(0)
@@ -568,31 +547,31 @@ if __name__ == "__main__":
try:
if "all" in args.tests:
# 运行所有测试
# Run all tests
if OutputControl.is_verbose():
print("\n基本功能测试")
run_test(test_non_stream_chat, "非流式调用测试")
run_test(test_stream_chat, "流式调用测试")
print("\nBasic Functionality Tests")
run_test(test_non_stream_chat, "Non-streaming Call Test")
run_test(test_stream_chat, "Streaming Call Test")
if OutputControl.is_verbose():
print("\n查询模式测试")
run_test(test_query_modes, "查询模式测试")
print("\nQuery Mode Tests")
run_test(test_query_modes, "Query Mode Test")
if OutputControl.is_verbose():
print("\n错误处理测试")
run_test(test_error_handling, "错误处理测试")
run_test(test_stream_error_handling, "流式错误处理测试")
print("\nError Handling Tests")
run_test(test_error_handling, "Error Handling Test")
run_test(test_stream_error_handling, "Streaming Error Handling Test")
else:
# 运行指定的测试
# Run specified tests
for test_name in args.tests:
if OutputControl.is_verbose():
print(f"\n运行测试: {test_name}")
print(f"\nRunning Test: {test_name}")
run_test(test_cases[test_name], test_name)
except Exception as e:
print(f"\n发生错误: {str(e)}")
print(f"\nAn error occurred: {str(e)}")
finally:
# 打印测试统计
# Print test statistics
STATS.print_summary()
# 如果指定了输出文件路径,则导出结果
# If an output file path is specified, export the results
if args.output:
STATS.export_results(args.output)