Translate unit test comment and promts to English
This commit is contained in:
@@ -19,7 +19,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
class OutputControl:
|
||||
"""输出控制类,管理测试输出的详细程度"""
|
||||
"""Output control class, manages the verbosity of test output"""
|
||||
_verbose: bool = False
|
||||
|
||||
@classmethod
|
||||
@@ -32,7 +32,7 @@ class OutputControl:
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""测试结果数据类"""
|
||||
"""Test result data class"""
|
||||
name: str
|
||||
success: bool
|
||||
duration: float
|
||||
@@ -44,7 +44,7 @@ class TestResult:
|
||||
self.timestamp = datetime.now().isoformat()
|
||||
|
||||
class TestStats:
|
||||
"""测试统计信息"""
|
||||
"""Test statistics"""
|
||||
def __init__(self):
|
||||
self.results: List[TestResult] = []
|
||||
self.start_time = datetime.now()
|
||||
@@ -53,10 +53,9 @@ class TestStats:
|
||||
self.results.append(result)
|
||||
|
||||
def export_results(self, path: str = "test_results.json"):
|
||||
"""导出测试结果到 JSON 文件
|
||||
|
||||
"""Export test results to a JSON file
|
||||
Args:
|
||||
path: 输出文件路径
|
||||
path: Output file path
|
||||
"""
|
||||
results_data = {
|
||||
"start_time": self.start_time.isoformat(),
|
||||
@@ -72,7 +71,7 @@ class TestStats:
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n测试结果已保存到: {path}")
|
||||
print(f"\nTest results saved to: {path}")
|
||||
|
||||
def print_summary(self):
|
||||
total = len(self.results)
|
||||
@@ -80,28 +79,27 @@ class TestStats:
|
||||
failed = total - passed
|
||||
duration = sum(r.duration for r in self.results)
|
||||
|
||||
print("\n=== 测试结果摘要 ===")
|
||||
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"总用时: {duration:.2f}秒")
|
||||
print(f"总计: {total} 个测试")
|
||||
print(f"通过: {passed} 个")
|
||||
print(f"失败: {failed} 个")
|
||||
print("\n=== Test Summary ===")
|
||||
print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f"Total duration: {duration:.2f} seconds")
|
||||
print(f"Total tests: {total}")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
|
||||
if failed > 0:
|
||||
print("\n失败的测试:")
|
||||
print("\nFailed tests:")
|
||||
for result in self.results:
|
||||
if not result.success:
|
||||
print(f"- {result.name}: {result.error}")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"server": {
|
||||
"host": "localhost",
|
||||
"port": 9621,
|
||||
"model": "lightrag:latest",
|
||||
"timeout": 30, # 请求超时时间(秒)
|
||||
"max_retries": 3, # 最大重试次数
|
||||
"retry_delay": 1 # 重试间隔(秒)
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1
|
||||
},
|
||||
"test_cases": {
|
||||
"basic": {
|
||||
@@ -111,18 +109,16 @@ DEFAULT_CONFIG = {
|
||||
}
|
||||
|
||||
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
||||
"""发送 HTTP 请求,支持重试机制
|
||||
|
||||
"""Send an HTTP request with retry mechanism
|
||||
Args:
|
||||
url: 请求 URL
|
||||
data: 请求数据
|
||||
stream: 是否使用流式响应
|
||||
|
||||
url: Request URL
|
||||
data: Request data
|
||||
stream: Whether to use streaming response
|
||||
Returns:
|
||||
requests.Response: 对象
|
||||
requests.Response: Response object
|
||||
|
||||
Raises:
|
||||
requests.exceptions.RequestException: 请求失败且重试次数用完
|
||||
requests.exceptions.RequestException: Request failed after all retries
|
||||
"""
|
||||
server_config = CONFIG["server"]
|
||||
max_retries = server_config["max_retries"]
|
||||
@@ -139,19 +135,18 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
||||
)
|
||||
return response
|
||||
except requests.exceptions.RequestException as e:
|
||||
if attempt == max_retries - 1: # 最后一次重试
|
||||
if attempt == max_retries - 1: # Last retry
|
||||
raise
|
||||
print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
|
||||
print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
|
||||
time.sleep(retry_delay)
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
"""加载配置文件
|
||||
|
||||
首先尝试从当前目录的 config.json 加载,
|
||||
如果不存在则使用默认配置
|
||||
"""Load configuration file
|
||||
|
||||
First try to load from config.json in the current directory,
|
||||
if it doesn't exist, use the default configuration
|
||||
Returns:
|
||||
配置字典
|
||||
Configuration dictionary
|
||||
"""
|
||||
config_path = Path("config.json")
|
||||
if config_path.exists():
|
||||
@@ -160,23 +155,22 @@ def load_config() -> Dict[str, Any]:
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
||||
"""格式化打印 JSON 响应数据
|
||||
|
||||
"""Format and print JSON response data
|
||||
Args:
|
||||
data: 要打印的数据字典
|
||||
title: 打印的标题
|
||||
indent: JSON 缩进空格数
|
||||
data: Data dictionary to print
|
||||
title: Title to print
|
||||
indent: Number of spaces for JSON indentation
|
||||
"""
|
||||
if OutputControl.is_verbose():
|
||||
if title:
|
||||
print(f"\n=== {title} ===")
|
||||
print(json.dumps(data, ensure_ascii=False, indent=indent))
|
||||
|
||||
# 全局配置
|
||||
# Global configuration
|
||||
CONFIG = load_config()
|
||||
|
||||
def get_base_url() -> str:
|
||||
"""返回基础 URL"""
|
||||
"""Return the base URL"""
|
||||
server = CONFIG["server"]
|
||||
return f"http://{server['host']}:{server['port']}/api/chat"
|
||||
|
||||
@@ -185,15 +179,13 @@ def create_request_data(
|
||||
stream: bool = False,
|
||||
model: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""创建基本的请求数据
|
||||
|
||||
"""Create basic request data
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
stream: 是否使用流式响应
|
||||
model: 模型名称
|
||||
|
||||
content: User message content
|
||||
stream: Whether to use streaming response
|
||||
model: Model name
|
||||
Returns:
|
||||
包含完整请求数据的字典
|
||||
Dictionary containing complete request data
|
||||
"""
|
||||
return {
|
||||
"model": model or CONFIG["server"]["model"],
|
||||
@@ -206,15 +198,14 @@ def create_request_data(
|
||||
"stream": stream
|
||||
}
|
||||
|
||||
# 全局测试统计
|
||||
# Global test statistics
|
||||
STATS = TestStats()
|
||||
|
||||
def run_test(func: Callable, name: str) -> None:
|
||||
"""运行测试并记录结果
|
||||
|
||||
"""Run a test and record the results
|
||||
Args:
|
||||
func: 测试函数
|
||||
name: 测试名称
|
||||
func: Test function
|
||||
name: Test name
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
@@ -227,54 +218,43 @@ def run_test(func: Callable, name: str) -> None:
|
||||
raise
|
||||
|
||||
def test_non_stream_chat():
|
||||
"""测试非流式调用 /api/chat 接口"""
|
||||
"""Test non-streaming call to /api/chat endpoint"""
|
||||
url = get_base_url()
|
||||
data = create_request_data(
|
||||
CONFIG["test_cases"]["basic"]["query"],
|
||||
stream=False
|
||||
)
|
||||
|
||||
# 发送请求
|
||||
# Send request
|
||||
response = make_request(url, data)
|
||||
|
||||
# 打印响应
|
||||
# Print response
|
||||
if OutputControl.is_verbose():
|
||||
print("\n=== 非流式调用响应 ===")
|
||||
print("\n=== Non-streaming call response ===")
|
||||
response_json = response.json()
|
||||
|
||||
# 打印响应内容
|
||||
# Print response content
|
||||
print_json_response({
|
||||
"model": response_json["model"],
|
||||
"message": response_json["message"]
|
||||
}, "响应内容")
|
||||
|
||||
# # 打印性能统计
|
||||
# print_json_response({
|
||||
# "total_duration": response_json["total_duration"],
|
||||
# "load_duration": response_json["load_duration"],
|
||||
# "prompt_eval_count": response_json["prompt_eval_count"],
|
||||
# "prompt_eval_duration": response_json["prompt_eval_duration"],
|
||||
# "eval_count": response_json["eval_count"],
|
||||
# "eval_duration": response_json["eval_duration"]
|
||||
# }, "性能统计")
|
||||
|
||||
}, "Response content")
|
||||
def test_stream_chat():
|
||||
"""测试流式调用 /api/chat 接口
|
||||
"""Test streaming call to /api/chat endpoint
|
||||
|
||||
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
||||
响应格式:
|
||||
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
|
||||
Response format:
|
||||
{
|
||||
"model": "lightrag:latest",
|
||||
"created_at": "2024-01-15T00:00:00Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "部分响应内容",
|
||||
"content": "Partial response content",
|
||||
"images": null
|
||||
},
|
||||
"done": false
|
||||
}
|
||||
|
||||
最后一条消息会包含性能统计信息,done 为 true。
|
||||
The last message will contain performance statistics, with done set to true.
|
||||
"""
|
||||
url = get_base_url()
|
||||
data = create_request_data(
|
||||
@@ -282,79 +262,79 @@ def test_stream_chat():
|
||||
stream=True
|
||||
)
|
||||
|
||||
# 发送请求并获取流式响应
|
||||
# Send request and get streaming response
|
||||
response = make_request(url, data, stream=True)
|
||||
|
||||
if OutputControl.is_verbose():
|
||||
print("\n=== 流式调用响应 ===")
|
||||
print("\n=== Streaming call response ===")
|
||||
output_buffer = []
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
if line: # 跳过空行
|
||||
if line: # Skip empty lines
|
||||
try:
|
||||
# 解码并解析 JSON
|
||||
# Decode and parse JSON
|
||||
data = json.loads(line.decode('utf-8'))
|
||||
if data.get("done", True): # 如果是完成标记
|
||||
if "total_duration" in data: # 最终的性能统计消息
|
||||
# print_json_response(data, "性能统计")
|
||||
if data.get("done", True): # If it's the completion marker
|
||||
if "total_duration" in data: # Final performance statistics message
|
||||
# print_json_response(data, "Performance statistics")
|
||||
break
|
||||
else: # 正常的内容消息
|
||||
else: # Normal content message
|
||||
message = data.get("message", {})
|
||||
content = message.get("content", "")
|
||||
if content: # 只收集非空内容
|
||||
if content: # Only collect non-empty content
|
||||
output_buffer.append(content)
|
||||
print(content, end="", flush=True) # 实时打印内容
|
||||
print(content, end="", flush=True) # Print content in real-time
|
||||
except json.JSONDecodeError:
|
||||
print("Error decoding JSON from response line")
|
||||
finally:
|
||||
response.close() # 确保关闭响应连接
|
||||
response.close() # Ensure the response connection is closed
|
||||
|
||||
# 打印一个换行
|
||||
# Print a newline
|
||||
print()
|
||||
|
||||
def test_query_modes():
|
||||
"""测试不同的查询模式前缀
|
||||
"""Test different query mode prefixes
|
||||
|
||||
支持的查询模式:
|
||||
- /local: 本地检索模式,只在相关度高的文档中搜索
|
||||
- /global: 全局检索模式,在所有文档中搜索
|
||||
- /naive: 朴素模式,不使用任何优化策略
|
||||
- /hybrid: 混合模式(默认),结合多种策略
|
||||
Supported query modes:
|
||||
- /local: Local retrieval mode, searches only in highly relevant documents
|
||||
- /global: Global retrieval mode, searches across all documents
|
||||
- /naive: Naive mode, does not use any optimization strategies
|
||||
- /hybrid: Hybrid mode (default), combines multiple strategies
|
||||
- /mix: Mix mode
|
||||
|
||||
每个模式都会返回相同格式的响应,但检索策略不同。
|
||||
Each mode will return responses in the same format, but with different retrieval strategies.
|
||||
"""
|
||||
url = get_base_url()
|
||||
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
||||
modes = ["local", "global", "naive", "hybrid", "mix"]
|
||||
|
||||
for mode in modes:
|
||||
if OutputControl.is_verbose():
|
||||
print(f"\n=== 测试 /{mode} 模式 ===")
|
||||
print(f"\n=== Testing /{mode} mode ===")
|
||||
data = create_request_data(
|
||||
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
||||
stream=False
|
||||
)
|
||||
|
||||
# 发送请求
|
||||
# Send request
|
||||
response = make_request(url, data)
|
||||
response_json = response.json()
|
||||
|
||||
# 打印响应内容
|
||||
# Print response content
|
||||
print_json_response({
|
||||
"model": response_json["model"],
|
||||
"message": response_json["message"]
|
||||
})
|
||||
|
||||
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||
"""创建用于错误测试的请求数据
|
||||
|
||||
"""Create request data for error testing
|
||||
Args:
|
||||
error_type: 错误类型,支持:
|
||||
- empty_messages: 空消息列表
|
||||
- invalid_role: 无效的角色字段
|
||||
- missing_content: 缺少内容字段
|
||||
error_type: Error type, supported:
|
||||
- empty_messages: Empty message list
|
||||
- invalid_role: Invalid role field
|
||||
- missing_content: Missing content field
|
||||
|
||||
Returns:
|
||||
包含错误数据的请求字典
|
||||
Request dictionary containing error data
|
||||
"""
|
||||
error_data = {
|
||||
"empty_messages": {
|
||||
@@ -367,7 +347,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||
"messages": [
|
||||
{
|
||||
"invalid_role": "user",
|
||||
"content": "测试消息"
|
||||
"content": "Test message"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
@@ -385,101 +365,100 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||
return error_data.get(error_type, error_data["empty_messages"])
|
||||
|
||||
def test_stream_error_handling():
|
||||
"""测试流式响应的错误处理
|
||||
"""Test error handling for streaming responses
|
||||
|
||||
测试场景:
|
||||
1. 空消息列表
|
||||
2. 消息格式错误(缺少必需字段)
|
||||
Test scenarios:
|
||||
1. Empty message list
|
||||
2. Message format error (missing required fields)
|
||||
|
||||
错误响应会立即返回,不会建立流式连接。
|
||||
状态码应该是 4xx,并返回详细的错误信息。
|
||||
Error responses should be returned immediately without establishing a streaming connection.
|
||||
The status code should be 4xx, and detailed error information should be returned.
|
||||
"""
|
||||
url = get_base_url()
|
||||
|
||||
if OutputControl.is_verbose():
|
||||
print("\n=== 测试流式响应错误处理 ===")
|
||||
print("\n=== Testing streaming response error handling ===")
|
||||
|
||||
# 测试空消息列表
|
||||
# Test empty message list
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- 测试空消息列表(流式)---")
|
||||
print("\n--- Testing empty message list (streaming) ---")
|
||||
data = create_error_test_data("empty_messages")
|
||||
response = make_request(url, data, stream=True)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"Status code: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print_json_response(response.json(), "错误信息")
|
||||
print_json_response(response.json(), "Error message")
|
||||
response.close()
|
||||
|
||||
# 测试无效角色字段
|
||||
# Test invalid role field
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- 测试无效角色字段(流式)---")
|
||||
print("\n--- Testing invalid role field (streaming) ---")
|
||||
data = create_error_test_data("invalid_role")
|
||||
response = make_request(url, data, stream=True)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"Status code: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print_json_response(response.json(), "错误信息")
|
||||
print_json_response(response.json(), "Error message")
|
||||
response.close()
|
||||
|
||||
# 测试缺少内容字段
|
||||
# Test missing content field
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- 测试缺少内容字段(流式)---")
|
||||
print("\n--- Testing missing content field (streaming) ---")
|
||||
data = create_error_test_data("missing_content")
|
||||
response = make_request(url, data, stream=True)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"Status code: {response.status_code}")
|
||||
if response.status_code != 200:
|
||||
print_json_response(response.json(), "错误信息")
|
||||
print_json_response(response.json(), "Error message")
|
||||
response.close()
|
||||
|
||||
def test_error_handling():
|
||||
"""测试非流式响应的错误处理
|
||||
"""Test error handling for non-streaming responses
|
||||
|
||||
测试场景:
|
||||
1. 空消息列表
|
||||
2. 消息格式错误(缺少必需字段)
|
||||
Test scenarios:
|
||||
1. Empty message list
|
||||
2. Message format error (missing required fields)
|
||||
|
||||
错误响应格式:
|
||||
Error response format:
|
||||
{
|
||||
"detail": "错误描述"
|
||||
"detail": "Error description"
|
||||
}
|
||||
|
||||
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
||||
All errors should return appropriate HTTP status codes and clear error messages.
|
||||
"""
|
||||
url = get_base_url()
|
||||
|
||||
if OutputControl.is_verbose():
|
||||
print("\n=== 测试错误处理 ===")
|
||||
print("\n=== Testing error handling ===")
|
||||
|
||||
# 测试空消息列表
|
||||
# Test empty message list
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- 测试空消息列表 ---")
|
||||
print("\n--- Testing empty message list ---")
|
||||
data = create_error_test_data("empty_messages")
|
||||
data["stream"] = False # 修改为非流式模式
|
||||
data["stream"] = False # Change to non-streaming mode
|
||||
response = make_request(url, data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print_json_response(response.json(), "错误信息")
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
# 测试无效角色字段
|
||||
# Test invalid role field
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- 测试无效角色字段 ---")
|
||||
print("\n--- Testing invalid role field ---")
|
||||
data = create_error_test_data("invalid_role")
|
||||
data["stream"] = False # 修改为非流式模式
|
||||
data["stream"] = False # Change to non-streaming mode
|
||||
response = make_request(url, data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print_json_response(response.json(), "错误信息")
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
# 测试缺少内容字段
|
||||
# Test missing content field
|
||||
if OutputControl.is_verbose():
|
||||
print("\n--- 测试缺少内容字段 ---")
|
||||
print("\n--- Testing missing content field ---")
|
||||
data = create_error_test_data("missing_content")
|
||||
data["stream"] = False # 修改为非流式模式
|
||||
data["stream"] = False # Change to non-streaming mode
|
||||
response = make_request(url, data)
|
||||
print(f"状态码: {response.status_code}")
|
||||
print_json_response(response.json(), "错误信息")
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
def get_test_cases() -> Dict[str, Callable]:
|
||||
"""获取所有可用的测试用例
|
||||
|
||||
"""Get all available test cases
|
||||
Returns:
|
||||
测试名称到测试函数的映射字典
|
||||
A dictionary mapping test names to test functions
|
||||
"""
|
||||
return {
|
||||
"non_stream": test_non_stream_chat,
|
||||
@@ -490,30 +469,30 @@ def get_test_cases() -> Dict[str, Callable]:
|
||||
}
|
||||
|
||||
def create_default_config():
|
||||
"""创建默认配置文件"""
|
||||
"""Create a default configuration file"""
|
||||
config_path = Path("config.json")
|
||||
if not config_path.exists():
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
|
||||
print(f"已创建默认配置文件: {config_path}")
|
||||
print(f"Default configuration file created: {config_path}")
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""解析命令行参数"""
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LightRAG Ollama 兼容接口测试",
|
||||
description="LightRAG Ollama Compatibility Interface Testing",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
配置文件 (config.json):
|
||||
Configuration file (config.json):
|
||||
{
|
||||
"server": {
|
||||
"host": "localhost", # 服务器地址
|
||||
"port": 9621, # 服务器端口
|
||||
"model": "lightrag:latest" # 默认模型名称
|
||||
"host": "localhost", # Server address
|
||||
"port": 9621, # Server port
|
||||
"model": "lightrag:latest" # Default model name
|
||||
},
|
||||
"test_cases": {
|
||||
"basic": {
|
||||
"query": "测试查询", # 基本查询文本
|
||||
"stream_query": "流式查询" # 流式查询文本
|
||||
"query": "Test query", # Basic query text
|
||||
"stream_query": "Stream query" # Stream query text
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -522,44 +501,44 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
"-q", "--quiet",
|
||||
action="store_true",
|
||||
help="静默模式,只显示测试结果摘要"
|
||||
help="Silent mode, only display test result summary"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--ask",
|
||||
type=str,
|
||||
help="指定查询内容,会覆盖配置文件中的查询设置"
|
||||
help="Specify query content, which will override the query settings in the configuration file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-config",
|
||||
action="store_true",
|
||||
help="创建默认配置文件"
|
||||
help="Create default configuration file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="",
|
||||
help="测试结果输出文件路径,默认不输出到文件"
|
||||
help="Test result output file path, default is not to output to a file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests",
|
||||
nargs="+",
|
||||
choices=list(get_test_cases().keys()) + ["all"],
|
||||
default=["all"],
|
||||
help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
|
||||
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# 设置输出模式
|
||||
# Set output mode
|
||||
OutputControl.set_verbose(not args.quiet)
|
||||
|
||||
# 如果指定了查询内容,更新配置
|
||||
# If query content is specified, update the configuration
|
||||
if args.ask:
|
||||
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
||||
|
||||
# 如果指定了创建配置文件
|
||||
# If specified to create a configuration file
|
||||
if args.init_config:
|
||||
create_default_config()
|
||||
exit(0)
|
||||
@@ -568,31 +547,31 @@ if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
if "all" in args.tests:
|
||||
# 运行所有测试
|
||||
# Run all tests
|
||||
if OutputControl.is_verbose():
|
||||
print("\n【基本功能测试】")
|
||||
run_test(test_non_stream_chat, "非流式调用测试")
|
||||
run_test(test_stream_chat, "流式调用测试")
|
||||
print("\n【Basic Functionality Tests】")
|
||||
run_test(test_non_stream_chat, "Non-streaming Call Test")
|
||||
run_test(test_stream_chat, "Streaming Call Test")
|
||||
|
||||
if OutputControl.is_verbose():
|
||||
print("\n【查询模式测试】")
|
||||
run_test(test_query_modes, "查询模式测试")
|
||||
print("\n【Query Mode Tests】")
|
||||
run_test(test_query_modes, "Query Mode Test")
|
||||
|
||||
if OutputControl.is_verbose():
|
||||
print("\n【错误处理测试】")
|
||||
run_test(test_error_handling, "错误处理测试")
|
||||
run_test(test_stream_error_handling, "流式错误处理测试")
|
||||
print("\n【Error Handling Tests】")
|
||||
run_test(test_error_handling, "Error Handling Test")
|
||||
run_test(test_stream_error_handling, "Streaming Error Handling Test")
|
||||
else:
|
||||
# 运行指定的测试
|
||||
# Run specified tests
|
||||
for test_name in args.tests:
|
||||
if OutputControl.is_verbose():
|
||||
print(f"\n【运行测试: {test_name}】")
|
||||
print(f"\n【Running Test: {test_name}】")
|
||||
run_test(test_cases[test_name], test_name)
|
||||
except Exception as e:
|
||||
print(f"\n发生错误: {str(e)}")
|
||||
print(f"\nAn error occurred: {str(e)}")
|
||||
finally:
|
||||
# 打印测试统计
|
||||
# Print test statistics
|
||||
STATS.print_summary()
|
||||
# 如果指定了输出文件路径,则导出结果
|
||||
# If an output file path is specified, export the results
|
||||
if args.output:
|
||||
STATS.export_results(args.output)
|
||||
|
Reference in New Issue
Block a user