Translate unit test comment and promts to English
This commit is contained in:
@@ -19,7 +19,7 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
class OutputControl:
|
class OutputControl:
|
||||||
"""输出控制类,管理测试输出的详细程度"""
|
"""Output control class, manages the verbosity of test output"""
|
||||||
_verbose: bool = False
|
_verbose: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -32,7 +32,7 @@ class OutputControl:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestResult:
|
class TestResult:
|
||||||
"""测试结果数据类"""
|
"""Test result data class"""
|
||||||
name: str
|
name: str
|
||||||
success: bool
|
success: bool
|
||||||
duration: float
|
duration: float
|
||||||
@@ -44,7 +44,7 @@ class TestResult:
|
|||||||
self.timestamp = datetime.now().isoformat()
|
self.timestamp = datetime.now().isoformat()
|
||||||
|
|
||||||
class TestStats:
|
class TestStats:
|
||||||
"""测试统计信息"""
|
"""Test statistics"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.results: List[TestResult] = []
|
self.results: List[TestResult] = []
|
||||||
self.start_time = datetime.now()
|
self.start_time = datetime.now()
|
||||||
@@ -53,10 +53,9 @@ class TestStats:
|
|||||||
self.results.append(result)
|
self.results.append(result)
|
||||||
|
|
||||||
def export_results(self, path: str = "test_results.json"):
|
def export_results(self, path: str = "test_results.json"):
|
||||||
"""导出测试结果到 JSON 文件
|
"""Export test results to a JSON file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: 输出文件路径
|
path: Output file path
|
||||||
"""
|
"""
|
||||||
results_data = {
|
results_data = {
|
||||||
"start_time": self.start_time.isoformat(),
|
"start_time": self.start_time.isoformat(),
|
||||||
@@ -72,7 +71,7 @@ class TestStats:
|
|||||||
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
||||||
print(f"\n测试结果已保存到: {path}")
|
print(f"\nTest results saved to: {path}")
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
total = len(self.results)
|
total = len(self.results)
|
||||||
@@ -80,28 +79,27 @@ class TestStats:
|
|||||||
failed = total - passed
|
failed = total - passed
|
||||||
duration = sum(r.duration for r in self.results)
|
duration = sum(r.duration for r in self.results)
|
||||||
|
|
||||||
print("\n=== 测试结果摘要 ===")
|
print("\n=== Test Summary ===")
|
||||||
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
print(f"总用时: {duration:.2f}秒")
|
print(f"Total duration: {duration:.2f} seconds")
|
||||||
print(f"总计: {total} 个测试")
|
print(f"Total tests: {total}")
|
||||||
print(f"通过: {passed} 个")
|
print(f"Passed: {passed}")
|
||||||
print(f"失败: {failed} 个")
|
print(f"Failed: {failed}")
|
||||||
|
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
print("\n失败的测试:")
|
print("\nFailed tests:")
|
||||||
for result in self.results:
|
for result in self.results:
|
||||||
if not result.success:
|
if not result.success:
|
||||||
print(f"- {result.name}: {result.error}")
|
print(f"- {result.name}: {result.error}")
|
||||||
|
|
||||||
# 默认配置
|
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"server": {
|
"server": {
|
||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 9621,
|
"port": 9621,
|
||||||
"model": "lightrag:latest",
|
"model": "lightrag:latest",
|
||||||
"timeout": 30, # 请求超时时间(秒)
|
"timeout": 30,
|
||||||
"max_retries": 3, # 最大重试次数
|
"max_retries": 3,
|
||||||
"retry_delay": 1 # 重试间隔(秒)
|
"retry_delay": 1
|
||||||
},
|
},
|
||||||
"test_cases": {
|
"test_cases": {
|
||||||
"basic": {
|
"basic": {
|
||||||
@@ -111,18 +109,16 @@ DEFAULT_CONFIG = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
||||||
"""发送 HTTP 请求,支持重试机制
|
"""Send an HTTP request with retry mechanism
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url: 请求 URL
|
url: Request URL
|
||||||
data: 请求数据
|
data: Request data
|
||||||
stream: 是否使用流式响应
|
stream: Whether to use streaming response
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
requests.Response: 对象
|
requests.Response: Response object
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
requests.exceptions.RequestException: 请求失败且重试次数用完
|
requests.exceptions.RequestException: Request failed after all retries
|
||||||
"""
|
"""
|
||||||
server_config = CONFIG["server"]
|
server_config = CONFIG["server"]
|
||||||
max_retries = server_config["max_retries"]
|
max_retries = server_config["max_retries"]
|
||||||
@@ -139,19 +135,18 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
if attempt == max_retries - 1: # 最后一次重试
|
if attempt == max_retries - 1: # Last retry
|
||||||
raise
|
raise
|
||||||
print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
|
print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
|
||||||
time.sleep(retry_delay)
|
time.sleep(retry_delay)
|
||||||
|
|
||||||
def load_config() -> Dict[str, Any]:
|
def load_config() -> Dict[str, Any]:
|
||||||
"""加载配置文件
|
"""Load configuration file
|
||||||
|
|
||||||
首先尝试从当前目录的 config.json 加载,
|
|
||||||
如果不存在则使用默认配置
|
|
||||||
|
|
||||||
|
First try to load from config.json in the current directory,
|
||||||
|
if it doesn't exist, use the default configuration
|
||||||
Returns:
|
Returns:
|
||||||
配置字典
|
Configuration dictionary
|
||||||
"""
|
"""
|
||||||
config_path = Path("config.json")
|
config_path = Path("config.json")
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
@@ -160,23 +155,22 @@ def load_config() -> Dict[str, Any]:
|
|||||||
return DEFAULT_CONFIG
|
return DEFAULT_CONFIG
|
||||||
|
|
||||||
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
||||||
"""格式化打印 JSON 响应数据
|
"""Format and print JSON response data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 要打印的数据字典
|
data: Data dictionary to print
|
||||||
title: 打印的标题
|
title: Title to print
|
||||||
indent: JSON 缩进空格数
|
indent: Number of spaces for JSON indentation
|
||||||
"""
|
"""
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
if title:
|
if title:
|
||||||
print(f"\n=== {title} ===")
|
print(f"\n=== {title} ===")
|
||||||
print(json.dumps(data, ensure_ascii=False, indent=indent))
|
print(json.dumps(data, ensure_ascii=False, indent=indent))
|
||||||
|
|
||||||
# 全局配置
|
# Global configuration
|
||||||
CONFIG = load_config()
|
CONFIG = load_config()
|
||||||
|
|
||||||
def get_base_url() -> str:
|
def get_base_url() -> str:
|
||||||
"""返回基础 URL"""
|
"""Return the base URL"""
|
||||||
server = CONFIG["server"]
|
server = CONFIG["server"]
|
||||||
return f"http://{server['host']}:{server['port']}/api/chat"
|
return f"http://{server['host']}:{server['port']}/api/chat"
|
||||||
|
|
||||||
@@ -185,15 +179,13 @@ def create_request_data(
|
|||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
model: str = None
|
model: str = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""创建基本的请求数据
|
"""Create basic request data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 用户消息内容
|
content: User message content
|
||||||
stream: 是否使用流式响应
|
stream: Whether to use streaming response
|
||||||
model: 模型名称
|
model: Model name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含完整请求数据的字典
|
Dictionary containing complete request data
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"model": model or CONFIG["server"]["model"],
|
"model": model or CONFIG["server"]["model"],
|
||||||
@@ -206,15 +198,14 @@ def create_request_data(
|
|||||||
"stream": stream
|
"stream": stream
|
||||||
}
|
}
|
||||||
|
|
||||||
# 全局测试统计
|
# Global test statistics
|
||||||
STATS = TestStats()
|
STATS = TestStats()
|
||||||
|
|
||||||
def run_test(func: Callable, name: str) -> None:
|
def run_test(func: Callable, name: str) -> None:
|
||||||
"""运行测试并记录结果
|
"""Run a test and record the results
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: 测试函数
|
func: Test function
|
||||||
name: 测试名称
|
name: Test name
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
@@ -227,54 +218,43 @@ def run_test(func: Callable, name: str) -> None:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def test_non_stream_chat():
|
def test_non_stream_chat():
|
||||||
"""测试非流式调用 /api/chat 接口"""
|
"""Test non-streaming call to /api/chat endpoint"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
data = create_request_data(
|
data = create_request_data(
|
||||||
CONFIG["test_cases"]["basic"]["query"],
|
CONFIG["test_cases"]["basic"]["query"],
|
||||||
stream=False
|
stream=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送请求
|
# Send request
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
|
|
||||||
# 打印响应
|
# Print response
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 非流式调用响应 ===")
|
print("\n=== Non-streaming call response ===")
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
# 打印响应内容
|
# Print response content
|
||||||
print_json_response({
|
print_json_response({
|
||||||
"model": response_json["model"],
|
"model": response_json["model"],
|
||||||
"message": response_json["message"]
|
"message": response_json["message"]
|
||||||
}, "响应内容")
|
}, "Response content")
|
||||||
|
|
||||||
# # 打印性能统计
|
|
||||||
# print_json_response({
|
|
||||||
# "total_duration": response_json["total_duration"],
|
|
||||||
# "load_duration": response_json["load_duration"],
|
|
||||||
# "prompt_eval_count": response_json["prompt_eval_count"],
|
|
||||||
# "prompt_eval_duration": response_json["prompt_eval_duration"],
|
|
||||||
# "eval_count": response_json["eval_count"],
|
|
||||||
# "eval_duration": response_json["eval_duration"]
|
|
||||||
# }, "性能统计")
|
|
||||||
|
|
||||||
def test_stream_chat():
|
def test_stream_chat():
|
||||||
"""测试流式调用 /api/chat 接口
|
"""Test streaming call to /api/chat endpoint
|
||||||
|
|
||||||
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
|
||||||
响应格式:
|
Response format:
|
||||||
{
|
{
|
||||||
"model": "lightrag:latest",
|
"model": "lightrag:latest",
|
||||||
"created_at": "2024-01-15T00:00:00Z",
|
"created_at": "2024-01-15T00:00:00Z",
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "部分响应内容",
|
"content": "Partial response content",
|
||||||
"images": null
|
"images": null
|
||||||
},
|
},
|
||||||
"done": false
|
"done": false
|
||||||
}
|
}
|
||||||
|
|
||||||
最后一条消息会包含性能统计信息,done 为 true。
|
The last message will contain performance statistics, with done set to true.
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
data = create_request_data(
|
data = create_request_data(
|
||||||
@@ -282,79 +262,79 @@ def test_stream_chat():
|
|||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送请求并获取流式响应
|
# Send request and get streaming response
|
||||||
response = make_request(url, data, stream=True)
|
response = make_request(url, data, stream=True)
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 流式调用响应 ===")
|
print("\n=== Streaming call response ===")
|
||||||
output_buffer = []
|
output_buffer = []
|
||||||
try:
|
try:
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if line: # 跳过空行
|
if line: # Skip empty lines
|
||||||
try:
|
try:
|
||||||
# 解码并解析 JSON
|
# Decode and parse JSON
|
||||||
data = json.loads(line.decode('utf-8'))
|
data = json.loads(line.decode('utf-8'))
|
||||||
if data.get("done", True): # 如果是完成标记
|
if data.get("done", True): # If it's the completion marker
|
||||||
if "total_duration" in data: # 最终的性能统计消息
|
if "total_duration" in data: # Final performance statistics message
|
||||||
# print_json_response(data, "性能统计")
|
# print_json_response(data, "Performance statistics")
|
||||||
break
|
break
|
||||||
else: # 正常的内容消息
|
else: # Normal content message
|
||||||
message = data.get("message", {})
|
message = data.get("message", {})
|
||||||
content = message.get("content", "")
|
content = message.get("content", "")
|
||||||
if content: # 只收集非空内容
|
if content: # Only collect non-empty content
|
||||||
output_buffer.append(content)
|
output_buffer.append(content)
|
||||||
print(content, end="", flush=True) # 实时打印内容
|
print(content, end="", flush=True) # Print content in real-time
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print("Error decoding JSON from response line")
|
print("Error decoding JSON from response line")
|
||||||
finally:
|
finally:
|
||||||
response.close() # 确保关闭响应连接
|
response.close() # Ensure the response connection is closed
|
||||||
|
|
||||||
# 打印一个换行
|
# Print a newline
|
||||||
print()
|
print()
|
||||||
|
|
||||||
def test_query_modes():
|
def test_query_modes():
|
||||||
"""测试不同的查询模式前缀
|
"""Test different query mode prefixes
|
||||||
|
|
||||||
支持的查询模式:
|
Supported query modes:
|
||||||
- /local: 本地检索模式,只在相关度高的文档中搜索
|
- /local: Local retrieval mode, searches only in highly relevant documents
|
||||||
- /global: 全局检索模式,在所有文档中搜索
|
- /global: Global retrieval mode, searches across all documents
|
||||||
- /naive: 朴素模式,不使用任何优化策略
|
- /naive: Naive mode, does not use any optimization strategies
|
||||||
- /hybrid: 混合模式(默认),结合多种策略
|
- /hybrid: Hybrid mode (default), combines multiple strategies
|
||||||
|
- /mix: Mix mode
|
||||||
|
|
||||||
每个模式都会返回相同格式的响应,但检索策略不同。
|
Each mode will return responses in the same format, but with different retrieval strategies.
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
modes = ["local", "global", "naive", "hybrid", "mix"]
|
||||||
|
|
||||||
for mode in modes:
|
for mode in modes:
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print(f"\n=== 测试 /{mode} 模式 ===")
|
print(f"\n=== Testing /{mode} mode ===")
|
||||||
data = create_request_data(
|
data = create_request_data(
|
||||||
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
||||||
stream=False
|
stream=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送请求
|
# Send request
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
# 打印响应内容
|
# Print response content
|
||||||
print_json_response({
|
print_json_response({
|
||||||
"model": response_json["model"],
|
"model": response_json["model"],
|
||||||
"message": response_json["message"]
|
"message": response_json["message"]
|
||||||
})
|
})
|
||||||
|
|
||||||
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||||
"""创建用于错误测试的请求数据
|
"""Create request data for error testing
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error_type: 错误类型,支持:
|
error_type: Error type, supported:
|
||||||
- empty_messages: 空消息列表
|
- empty_messages: Empty message list
|
||||||
- invalid_role: 无效的角色字段
|
- invalid_role: Invalid role field
|
||||||
- missing_content: 缺少内容字段
|
- missing_content: Missing content field
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含错误数据的请求字典
|
Request dictionary containing error data
|
||||||
"""
|
"""
|
||||||
error_data = {
|
error_data = {
|
||||||
"empty_messages": {
|
"empty_messages": {
|
||||||
@@ -367,7 +347,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
|||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"invalid_role": "user",
|
"invalid_role": "user",
|
||||||
"content": "测试消息"
|
"content": "Test message"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"stream": True
|
"stream": True
|
||||||
@@ -385,101 +365,100 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
|||||||
return error_data.get(error_type, error_data["empty_messages"])
|
return error_data.get(error_type, error_data["empty_messages"])
|
||||||
|
|
||||||
def test_stream_error_handling():
|
def test_stream_error_handling():
|
||||||
"""测试流式响应的错误处理
|
"""Test error handling for streaming responses
|
||||||
|
|
||||||
测试场景:
|
Test scenarios:
|
||||||
1. 空消息列表
|
1. Empty message list
|
||||||
2. 消息格式错误(缺少必需字段)
|
2. Message format error (missing required fields)
|
||||||
|
|
||||||
错误响应会立即返回,不会建立流式连接。
|
Error responses should be returned immediately without establishing a streaming connection.
|
||||||
状态码应该是 4xx,并返回详细的错误信息。
|
The status code should be 4xx, and detailed error information should be returned.
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 测试流式响应错误处理 ===")
|
print("\n=== Testing streaming response error handling ===")
|
||||||
|
|
||||||
# 测试空消息列表
|
# Test empty message list
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试空消息列表(流式)---")
|
print("\n--- Testing empty message list (streaming) ---")
|
||||||
data = create_error_test_data("empty_messages")
|
data = create_error_test_data("empty_messages")
|
||||||
response = make_request(url, data, stream=True)
|
response = make_request(url, data, stream=True)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"Status code: {response.status_code}")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "Error message")
|
||||||
response.close()
|
response.close()
|
||||||
|
|
||||||
# 测试无效角色字段
|
# Test invalid role field
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试无效角色字段(流式)---")
|
print("\n--- Testing invalid role field (streaming) ---")
|
||||||
data = create_error_test_data("invalid_role")
|
data = create_error_test_data("invalid_role")
|
||||||
response = make_request(url, data, stream=True)
|
response = make_request(url, data, stream=True)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"Status code: {response.status_code}")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "Error message")
|
||||||
response.close()
|
response.close()
|
||||||
|
|
||||||
# 测试缺少内容字段
|
# Test missing content field
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试缺少内容字段(流式)---")
|
print("\n--- Testing missing content field (streaming) ---")
|
||||||
data = create_error_test_data("missing_content")
|
data = create_error_test_data("missing_content")
|
||||||
response = make_request(url, data, stream=True)
|
response = make_request(url, data, stream=True)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"Status code: {response.status_code}")
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "Error message")
|
||||||
response.close()
|
response.close()
|
||||||
|
|
||||||
def test_error_handling():
|
def test_error_handling():
|
||||||
"""测试非流式响应的错误处理
|
"""Test error handling for non-streaming responses
|
||||||
|
|
||||||
测试场景:
|
Test scenarios:
|
||||||
1. 空消息列表
|
1. Empty message list
|
||||||
2. 消息格式错误(缺少必需字段)
|
2. Message format error (missing required fields)
|
||||||
|
|
||||||
错误响应格式:
|
Error response format:
|
||||||
{
|
{
|
||||||
"detail": "错误描述"
|
"detail": "Error description"
|
||||||
}
|
}
|
||||||
|
|
||||||
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
All errors should return appropriate HTTP status codes and clear error messages.
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 测试错误处理 ===")
|
print("\n=== Testing error handling ===")
|
||||||
|
|
||||||
# 测试空消息列表
|
# Test empty message list
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试空消息列表 ---")
|
print("\n--- Testing empty message list ---")
|
||||||
data = create_error_test_data("empty_messages")
|
data = create_error_test_data("empty_messages")
|
||||||
data["stream"] = False # 修改为非流式模式
|
data["stream"] = False # Change to non-streaming mode
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"Status code: {response.status_code}")
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "Error message")
|
||||||
|
|
||||||
# 测试无效角色字段
|
# Test invalid role field
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试无效角色字段 ---")
|
print("\n--- Testing invalid role field ---")
|
||||||
data = create_error_test_data("invalid_role")
|
data = create_error_test_data("invalid_role")
|
||||||
data["stream"] = False # 修改为非流式模式
|
data["stream"] = False # Change to non-streaming mode
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"Status code: {response.status_code}")
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "Error message")
|
||||||
|
|
||||||
# 测试缺少内容字段
|
# Test missing content field
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试缺少内容字段 ---")
|
print("\n--- Testing missing content field ---")
|
||||||
data = create_error_test_data("missing_content")
|
data = create_error_test_data("missing_content")
|
||||||
data["stream"] = False # 修改为非流式模式
|
data["stream"] = False # Change to non-streaming mode
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"Status code: {response.status_code}")
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "Error message")
|
||||||
|
|
||||||
def get_test_cases() -> Dict[str, Callable]:
|
def get_test_cases() -> Dict[str, Callable]:
|
||||||
"""获取所有可用的测试用例
|
"""Get all available test cases
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
测试名称到测试函数的映射字典
|
A dictionary mapping test names to test functions
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"non_stream": test_non_stream_chat,
|
"non_stream": test_non_stream_chat,
|
||||||
@@ -490,30 +469,30 @@ def get_test_cases() -> Dict[str, Callable]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def create_default_config():
|
def create_default_config():
|
||||||
"""创建默认配置文件"""
|
"""Create a default configuration file"""
|
||||||
config_path = Path("config.json")
|
config_path = Path("config.json")
|
||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
|
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
|
||||||
print(f"已创建默认配置文件: {config_path}")
|
print(f"Default configuration file created: {config_path}")
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
"""解析命令行参数"""
|
"""Parse command line arguments"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="LightRAG Ollama 兼容接口测试",
|
description="LightRAG Ollama Compatibility Interface Testing",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
配置文件 (config.json):
|
Configuration file (config.json):
|
||||||
{
|
{
|
||||||
"server": {
|
"server": {
|
||||||
"host": "localhost", # 服务器地址
|
"host": "localhost", # Server address
|
||||||
"port": 9621, # 服务器端口
|
"port": 9621, # Server port
|
||||||
"model": "lightrag:latest" # 默认模型名称
|
"model": "lightrag:latest" # Default model name
|
||||||
},
|
},
|
||||||
"test_cases": {
|
"test_cases": {
|
||||||
"basic": {
|
"basic": {
|
||||||
"query": "测试查询", # 基本查询文本
|
"query": "Test query", # Basic query text
|
||||||
"stream_query": "流式查询" # 流式查询文本
|
"stream_query": "Stream query" # Stream query text
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -522,44 +501,44 @@ def parse_args() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-q", "--quiet",
|
"-q", "--quiet",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="静默模式,只显示测试结果摘要"
|
help="Silent mode, only display test result summary"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-a", "--ask",
|
"-a", "--ask",
|
||||||
type=str,
|
type=str,
|
||||||
help="指定查询内容,会覆盖配置文件中的查询设置"
|
help="Specify query content, which will override the query settings in the configuration file"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--init-config",
|
"--init-config",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="创建默认配置文件"
|
help="Create default configuration file"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output",
|
"--output",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="测试结果输出文件路径,默认不输出到文件"
|
help="Test result output file path, default is not to output to a file"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tests",
|
"--tests",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
choices=list(get_test_cases().keys()) + ["all"],
|
choices=list(get_test_cases().keys()) + ["all"],
|
||||||
default=["all"],
|
default=["all"],
|
||||||
help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
|
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 设置输出模式
|
# Set output mode
|
||||||
OutputControl.set_verbose(not args.quiet)
|
OutputControl.set_verbose(not args.quiet)
|
||||||
|
|
||||||
# 如果指定了查询内容,更新配置
|
# If query content is specified, update the configuration
|
||||||
if args.ask:
|
if args.ask:
|
||||||
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
||||||
|
|
||||||
# 如果指定了创建配置文件
|
# If specified to create a configuration file
|
||||||
if args.init_config:
|
if args.init_config:
|
||||||
create_default_config()
|
create_default_config()
|
||||||
exit(0)
|
exit(0)
|
||||||
@@ -568,31 +547,31 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if "all" in args.tests:
|
if "all" in args.tests:
|
||||||
# 运行所有测试
|
# Run all tests
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n【基本功能测试】")
|
print("\n【Basic Functionality Tests】")
|
||||||
run_test(test_non_stream_chat, "非流式调用测试")
|
run_test(test_non_stream_chat, "Non-streaming Call Test")
|
||||||
run_test(test_stream_chat, "流式调用测试")
|
run_test(test_stream_chat, "Streaming Call Test")
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n【查询模式测试】")
|
print("\n【Query Mode Tests】")
|
||||||
run_test(test_query_modes, "查询模式测试")
|
run_test(test_query_modes, "Query Mode Test")
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n【错误处理测试】")
|
print("\n【Error Handling Tests】")
|
||||||
run_test(test_error_handling, "错误处理测试")
|
run_test(test_error_handling, "Error Handling Test")
|
||||||
run_test(test_stream_error_handling, "流式错误处理测试")
|
run_test(test_stream_error_handling, "Streaming Error Handling Test")
|
||||||
else:
|
else:
|
||||||
# 运行指定的测试
|
# Run specified tests
|
||||||
for test_name in args.tests:
|
for test_name in args.tests:
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print(f"\n【运行测试: {test_name}】")
|
print(f"\n【Running Test: {test_name}】")
|
||||||
run_test(test_cases[test_name], test_name)
|
run_test(test_cases[test_name], test_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n发生错误: {str(e)}")
|
print(f"\nAn error occurred: {str(e)}")
|
||||||
finally:
|
finally:
|
||||||
# 打印测试统计
|
# Print test statistics
|
||||||
STATS.print_summary()
|
STATS.print_summary()
|
||||||
# 如果指定了输出文件路径,则导出结果
|
# If an output file path is specified, export the results
|
||||||
if args.output:
|
if args.output:
|
||||||
STATS.export_results(args.output)
|
STATS.export_results(args.output)
|
||||||
|
Reference in New Issue
Block a user