Translate comment to English
This commit is contained in:
@@ -27,15 +27,15 @@ from dotenv import load_dotenv
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
def estimate_tokens(text: str) -> int:
|
def estimate_tokens(text: str) -> int:
|
||||||
"""估算文本的token数量
|
"""Estimate the number of tokens in text
|
||||||
中文每字约1.5个token
|
Chinese characters: approximately 1.5 tokens per character
|
||||||
英文每字约0.25个token
|
English characters: approximately 0.25 tokens per character
|
||||||
"""
|
"""
|
||||||
# 使用正则表达式分别匹配中文字符和非中文字符
|
# Use regex to match Chinese and non-Chinese characters separately
|
||||||
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
||||||
non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
|
non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
|
||||||
|
|
||||||
# 计算估算的token数量
|
# Calculate estimated token count
|
||||||
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
|
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
|
||||||
|
|
||||||
return int(tokens)
|
return int(tokens)
|
||||||
@@ -241,7 +241,7 @@ class DocumentManager:
|
|||||||
class SearchMode(str, Enum):
|
class SearchMode(str, Enum):
|
||||||
naive = "naive"
|
naive = "naive"
|
||||||
local = "local"
|
local = "local"
|
||||||
global_ = "global" # 使用 global_ 因为 global 是 Python 保留关键字,但枚举值会转换为字符串 "global"
|
global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global"
|
||||||
hybrid = "hybrid"
|
hybrid = "hybrid"
|
||||||
mix = "mix"
|
mix = "mix"
|
||||||
|
|
||||||
@@ -254,7 +254,7 @@ class OllamaMessage(BaseModel):
|
|||||||
class OllamaChatRequest(BaseModel):
|
class OllamaChatRequest(BaseModel):
|
||||||
model: str = LIGHTRAG_MODEL
|
model: str = LIGHTRAG_MODEL
|
||||||
messages: List[OllamaMessage]
|
messages: List[OllamaMessage]
|
||||||
stream: bool = True # 默认为流式模式
|
stream: bool = True # Default to streaming mode
|
||||||
options: Optional[Dict[str, Any]] = None
|
options: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
class OllamaChatResponse(BaseModel):
|
class OllamaChatResponse(BaseModel):
|
||||||
@@ -490,11 +490,11 @@ def create_app(args):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果响应是字符串(比如命中缓存),直接返回
|
# If response is a string (e.g. cache hit), return directly
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
return QueryResponse(response=response)
|
return QueryResponse(response=response)
|
||||||
|
|
||||||
# 如果是异步生成器,根据stream参数决定是否流式返回
|
# If it's an async generator, decide whether to stream based on stream parameter
|
||||||
if request.stream:
|
if request.stream:
|
||||||
result = ""
|
result = ""
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
@@ -511,7 +511,7 @@ def create_app(args):
|
|||||||
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||||
async def query_text_stream(request: QueryRequest):
|
async def query_text_stream(request: QueryRequest):
|
||||||
try:
|
try:
|
||||||
response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await
|
response = await rag.aquery( # Use aquery instead of query, and add await
|
||||||
request.query,
|
request.query,
|
||||||
param=QueryParam(
|
param=QueryParam(
|
||||||
mode=request.mode,
|
mode=request.mode,
|
||||||
@@ -691,7 +691,7 @@ def create_app(args):
|
|||||||
|
|
||||||
for prefix, mode in mode_map.items():
|
for prefix, mode in mode_map.items():
|
||||||
if query.startswith(prefix):
|
if query.startswith(prefix):
|
||||||
# 移除前缀后,清理开头的额外空格
|
# After removing prefix an leading spaces
|
||||||
cleaned_query = query[len(prefix):].lstrip()
|
cleaned_query = query[len(prefix):].lstrip()
|
||||||
return cleaned_query, mode
|
return cleaned_query, mode
|
||||||
|
|
||||||
@@ -699,17 +699,14 @@ def create_app(args):
|
|||||||
|
|
||||||
@app.post("/api/chat")
|
@app.post("/api/chat")
|
||||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||||
# # 打印原始请求数据
|
|
||||||
# body = await raw_request.body()
|
|
||||||
# logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}")
|
|
||||||
"""Handle chat completion requests"""
|
"""Handle chat completion requests"""
|
||||||
try:
|
try:
|
||||||
# 获取所有消息内容
|
# Get all messages
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
if not messages:
|
if not messages:
|
||||||
raise HTTPException(status_code=400, detail="No messages provided")
|
raise HTTPException(status_code=400, detail="No messages provided")
|
||||||
|
|
||||||
# 获取最后一条消息作为查询
|
# Get the last message as query
|
||||||
query = messages[-1].content
|
query = messages[-1].content
|
||||||
|
|
||||||
# 解析查询模式
|
# 解析查询模式
|
||||||
@@ -723,7 +720,7 @@ def create_app(args):
|
|||||||
|
|
||||||
# 调用RAG进行查询
|
# 调用RAG进行查询
|
||||||
query_param = QueryParam(
|
query_param = QueryParam(
|
||||||
mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid
|
mode=mode,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
only_need_context=False
|
only_need_context=False
|
||||||
)
|
)
|
||||||
@@ -731,7 +728,7 @@ def create_app(args):
|
|||||||
if request.stream:
|
if request.stream:
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
response = await rag.aquery( # 需要 await 来获取异步生成器
|
response = await rag.aquery( # Need await to get async generator
|
||||||
cleaned_query,
|
cleaned_query,
|
||||||
param=query_param
|
param=query_param
|
||||||
)
|
)
|
||||||
@@ -742,9 +739,9 @@ def create_app(args):
|
|||||||
last_chunk_time = None
|
last_chunk_time = None
|
||||||
total_response = ""
|
total_response = ""
|
||||||
|
|
||||||
# 确保 response 是异步生成器
|
# Ensure response is an async generator
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
# 如果是字符串,分两次发送
|
# If it's a string, send in two parts
|
||||||
first_chunk_time = time.time_ns()
|
first_chunk_time = time.time_ns()
|
||||||
last_chunk_time = first_chunk_time
|
last_chunk_time = first_chunk_time
|
||||||
total_response = response
|
total_response = response
|
||||||
|
@@ -1,12 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
LightRAG Ollama 兼容接口测试脚本
|
LightRAG Ollama Compatibility Interface Test Script
|
||||||
|
|
||||||
这个脚本测试 LightRAG 的 Ollama 兼容接口,包括:
|
This script tests the LightRAG's Ollama compatibility interface, including:
|
||||||
1. 基本功能测试(流式和非流式响应)
|
1. Basic functionality tests (streaming and non-streaming responses)
|
||||||
2. 查询模式测试(local、global、naive、hybrid)
|
2. Query mode tests (local, global, naive, hybrid)
|
||||||
3. 错误处理测试(包括流式和非流式场景)
|
3. Error handling tests (including streaming and non-streaming scenarios)
|
||||||
|
|
||||||
所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。
|
All responses use the JSON Lines format, complying with the Ollama API specification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -24,20 +24,10 @@ class OutputControl:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_verbose(cls, verbose: bool) -> None:
|
def set_verbose(cls, verbose: bool) -> None:
|
||||||
"""设置输出详细程度
|
|
||||||
|
|
||||||
Args:
|
|
||||||
verbose: True 为详细模式,False 为静默模式
|
|
||||||
"""
|
|
||||||
cls._verbose = verbose
|
cls._verbose = verbose
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_verbose(cls) -> bool:
|
def is_verbose(cls) -> bool:
|
||||||
"""获取当前输出模式
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
当前是否为详细模式
|
|
||||||
"""
|
|
||||||
return cls._verbose
|
return cls._verbose
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -48,9 +38,8 @@ class TestResult:
|
|||||||
duration: float
|
duration: float
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
timestamp: str = ""
|
timestamp: str = ""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""初始化后设置时间戳"""
|
|
||||||
if not self.timestamp:
|
if not self.timestamp:
|
||||||
self.timestamp = datetime.now().isoformat()
|
self.timestamp = datetime.now().isoformat()
|
||||||
|
|
||||||
@@ -59,14 +48,13 @@ class TestStats:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.results: List[TestResult] = []
|
self.results: List[TestResult] = []
|
||||||
self.start_time = datetime.now()
|
self.start_time = datetime.now()
|
||||||
|
|
||||||
def add_result(self, result: TestResult):
|
def add_result(self, result: TestResult):
|
||||||
"""添加测试结果"""
|
|
||||||
self.results.append(result)
|
self.results.append(result)
|
||||||
|
|
||||||
def export_results(self, path: str = "test_results.json"):
|
def export_results(self, path: str = "test_results.json"):
|
||||||
"""导出测试结果到 JSON 文件
|
"""导出测试结果到 JSON 文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: 输出文件路径
|
path: 输出文件路径
|
||||||
"""
|
"""
|
||||||
@@ -81,25 +69,24 @@ class TestStats:
|
|||||||
"total_duration": sum(r.duration for r in self.results)
|
"total_duration": sum(r.duration for r in self.results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
||||||
print(f"\n测试结果已保存到: {path}")
|
print(f"\n测试结果已保存到: {path}")
|
||||||
|
|
||||||
def print_summary(self):
|
def print_summary(self):
|
||||||
"""打印测试统计摘要"""
|
|
||||||
total = len(self.results)
|
total = len(self.results)
|
||||||
passed = sum(1 for r in self.results if r.success)
|
passed = sum(1 for r in self.results if r.success)
|
||||||
failed = total - passed
|
failed = total - passed
|
||||||
duration = sum(r.duration for r in self.results)
|
duration = sum(r.duration for r in self.results)
|
||||||
|
|
||||||
print("\n=== 测试结果摘要 ===")
|
print("\n=== 测试结果摘要 ===")
|
||||||
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
print(f"总用时: {duration:.2f}秒")
|
print(f"总用时: {duration:.2f}秒")
|
||||||
print(f"总计: {total} 个测试")
|
print(f"总计: {total} 个测试")
|
||||||
print(f"通过: {passed} 个")
|
print(f"通过: {passed} 个")
|
||||||
print(f"失败: {failed} 个")
|
print(f"失败: {failed} 个")
|
||||||
|
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
print("\n失败的测试:")
|
print("\n失败的测试:")
|
||||||
for result in self.results:
|
for result in self.results:
|
||||||
@@ -125,15 +112,15 @@ DEFAULT_CONFIG = {
|
|||||||
|
|
||||||
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
||||||
"""发送 HTTP 请求,支持重试机制
|
"""发送 HTTP 请求,支持重试机制
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url: 请求 URL
|
url: 请求 URL
|
||||||
data: 请求数据
|
data: 请求数据
|
||||||
stream: 是否使用流式响应
|
stream: 是否使用流式响应
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
requests.Response 对象
|
requests.Response: 对象
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
requests.exceptions.RequestException: 请求失败且重试次数用完
|
requests.exceptions.RequestException: 请求失败且重试次数用完
|
||||||
"""
|
"""
|
||||||
@@ -141,7 +128,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
|||||||
max_retries = server_config["max_retries"]
|
max_retries = server_config["max_retries"]
|
||||||
retry_delay = server_config["retry_delay"]
|
retry_delay = server_config["retry_delay"]
|
||||||
timeout = server_config["timeout"]
|
timeout = server_config["timeout"]
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@@ -159,10 +146,10 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
|||||||
|
|
||||||
def load_config() -> Dict[str, Any]:
|
def load_config() -> Dict[str, Any]:
|
||||||
"""加载配置文件
|
"""加载配置文件
|
||||||
|
|
||||||
首先尝试从当前目录的 config.json 加载,
|
首先尝试从当前目录的 config.json 加载,
|
||||||
如果不存在则使用默认配置
|
如果不存在则使用默认配置
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
配置字典
|
配置字典
|
||||||
"""
|
"""
|
||||||
@@ -174,7 +161,7 @@ def load_config() -> Dict[str, Any]:
|
|||||||
|
|
||||||
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
||||||
"""格式化打印 JSON 响应数据
|
"""格式化打印 JSON 响应数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 要打印的数据字典
|
data: 要打印的数据字典
|
||||||
title: 打印的标题
|
title: 打印的标题
|
||||||
@@ -199,12 +186,12 @@ def create_request_data(
|
|||||||
model: str = None
|
model: str = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""创建基本的请求数据
|
"""创建基本的请求数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 用户消息内容
|
content: 用户消息内容
|
||||||
stream: 是否使用流式响应
|
stream: 是否使用流式响应
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含完整请求数据的字典
|
包含完整请求数据的字典
|
||||||
"""
|
"""
|
||||||
@@ -224,7 +211,7 @@ STATS = TestStats()
|
|||||||
|
|
||||||
def run_test(func: Callable, name: str) -> None:
|
def run_test(func: Callable, name: str) -> None:
|
||||||
"""运行测试并记录结果
|
"""运行测试并记录结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: 测试函数
|
func: 测试函数
|
||||||
name: 测试名称
|
name: 测试名称
|
||||||
@@ -246,21 +233,21 @@ def test_non_stream_chat():
|
|||||||
CONFIG["test_cases"]["basic"]["query"],
|
CONFIG["test_cases"]["basic"]["query"],
|
||||||
stream=False
|
stream=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送请求
|
# 发送请求
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
|
|
||||||
# 打印响应
|
# 打印响应
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 非流式调用响应 ===")
|
print("\n=== 非流式调用响应 ===")
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
# 打印响应内容
|
# 打印响应内容
|
||||||
print_json_response({
|
print_json_response({
|
||||||
"model": response_json["model"],
|
"model": response_json["model"],
|
||||||
"message": response_json["message"]
|
"message": response_json["message"]
|
||||||
}, "响应内容")
|
}, "响应内容")
|
||||||
|
|
||||||
# # 打印性能统计
|
# # 打印性能统计
|
||||||
# print_json_response({
|
# print_json_response({
|
||||||
# "total_duration": response_json["total_duration"],
|
# "total_duration": response_json["total_duration"],
|
||||||
@@ -273,7 +260,7 @@ def test_non_stream_chat():
|
|||||||
|
|
||||||
def test_stream_chat():
|
def test_stream_chat():
|
||||||
"""测试流式调用 /api/chat 接口
|
"""测试流式调用 /api/chat 接口
|
||||||
|
|
||||||
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
||||||
响应格式:
|
响应格式:
|
||||||
{
|
{
|
||||||
@@ -286,7 +273,7 @@ def test_stream_chat():
|
|||||||
},
|
},
|
||||||
"done": false
|
"done": false
|
||||||
}
|
}
|
||||||
|
|
||||||
最后一条消息会包含性能统计信息,done 为 true。
|
最后一条消息会包含性能统计信息,done 为 true。
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
@@ -294,10 +281,10 @@ def test_stream_chat():
|
|||||||
CONFIG["test_cases"]["basic"]["query"],
|
CONFIG["test_cases"]["basic"]["query"],
|
||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送请求并获取流式响应
|
# 发送请求并获取流式响应
|
||||||
response = make_request(url, data, stream=True)
|
response = make_request(url, data, stream=True)
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 流式调用响应 ===")
|
print("\n=== 流式调用响应 ===")
|
||||||
output_buffer = []
|
output_buffer = []
|
||||||
@@ -321,24 +308,24 @@ def test_stream_chat():
|
|||||||
print("Error decoding JSON from response line")
|
print("Error decoding JSON from response line")
|
||||||
finally:
|
finally:
|
||||||
response.close() # 确保关闭响应连接
|
response.close() # 确保关闭响应连接
|
||||||
|
|
||||||
# 打印一个换行
|
# 打印一个换行
|
||||||
print()
|
print()
|
||||||
|
|
||||||
def test_query_modes():
|
def test_query_modes():
|
||||||
"""测试不同的查询模式前缀
|
"""测试不同的查询模式前缀
|
||||||
|
|
||||||
支持的查询模式:
|
支持的查询模式:
|
||||||
- /local: 本地检索模式,只在相关度高的文档中搜索
|
- /local: 本地检索模式,只在相关度高的文档中搜索
|
||||||
- /global: 全局检索模式,在所有文档中搜索
|
- /global: 全局检索模式,在所有文档中搜索
|
||||||
- /naive: 朴素模式,不使用任何优化策略
|
- /naive: 朴素模式,不使用任何优化策略
|
||||||
- /hybrid: 混合模式(默认),结合多种策略
|
- /hybrid: 混合模式(默认),结合多种策略
|
||||||
|
|
||||||
每个模式都会返回相同格式的响应,但检索策略不同。
|
每个模式都会返回相同格式的响应,但检索策略不同。
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
||||||
|
|
||||||
for mode in modes:
|
for mode in modes:
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print(f"\n=== 测试 /{mode} 模式 ===")
|
print(f"\n=== 测试 /{mode} 模式 ===")
|
||||||
@@ -346,11 +333,11 @@ def test_query_modes():
|
|||||||
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
||||||
stream=False
|
stream=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送请求
|
# 发送请求
|
||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
|
|
||||||
# 打印响应内容
|
# 打印响应内容
|
||||||
print_json_response({
|
print_json_response({
|
||||||
"model": response_json["model"],
|
"model": response_json["model"],
|
||||||
@@ -359,13 +346,13 @@ def test_query_modes():
|
|||||||
|
|
||||||
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||||
"""创建用于错误测试的请求数据
|
"""创建用于错误测试的请求数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error_type: 错误类型,支持:
|
error_type: 错误类型,支持:
|
||||||
- empty_messages: 空消息列表
|
- empty_messages: 空消息列表
|
||||||
- invalid_role: 无效的角色字段
|
- invalid_role: 无效的角色字段
|
||||||
- missing_content: 缺少内容字段
|
- missing_content: 缺少内容字段
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含错误数据的请求字典
|
包含错误数据的请求字典
|
||||||
"""
|
"""
|
||||||
@@ -399,19 +386,19 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
|||||||
|
|
||||||
def test_stream_error_handling():
|
def test_stream_error_handling():
|
||||||
"""测试流式响应的错误处理
|
"""测试流式响应的错误处理
|
||||||
|
|
||||||
测试场景:
|
测试场景:
|
||||||
1. 空消息列表
|
1. 空消息列表
|
||||||
2. 消息格式错误(缺少必需字段)
|
2. 消息格式错误(缺少必需字段)
|
||||||
|
|
||||||
错误响应会立即返回,不会建立流式连接。
|
错误响应会立即返回,不会建立流式连接。
|
||||||
状态码应该是 4xx,并返回详细的错误信息。
|
状态码应该是 4xx,并返回详细的错误信息。
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 测试流式响应错误处理 ===")
|
print("\n=== 测试流式响应错误处理 ===")
|
||||||
|
|
||||||
# 测试空消息列表
|
# 测试空消息列表
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试空消息列表(流式)---")
|
print("\n--- 测试空消息列表(流式)---")
|
||||||
@@ -421,7 +408,7 @@ def test_stream_error_handling():
|
|||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "错误信息")
|
||||||
response.close()
|
response.close()
|
||||||
|
|
||||||
# 测试无效角色字段
|
# 测试无效角色字段
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试无效角色字段(流式)---")
|
print("\n--- 测试无效角色字段(流式)---")
|
||||||
@@ -444,23 +431,23 @@ def test_stream_error_handling():
|
|||||||
|
|
||||||
def test_error_handling():
|
def test_error_handling():
|
||||||
"""测试非流式响应的错误处理
|
"""测试非流式响应的错误处理
|
||||||
|
|
||||||
测试场景:
|
测试场景:
|
||||||
1. 空消息列表
|
1. 空消息列表
|
||||||
2. 消息格式错误(缺少必需字段)
|
2. 消息格式错误(缺少必需字段)
|
||||||
|
|
||||||
错误响应格式:
|
错误响应格式:
|
||||||
{
|
{
|
||||||
"detail": "错误描述"
|
"detail": "错误描述"
|
||||||
}
|
}
|
||||||
|
|
||||||
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
||||||
"""
|
"""
|
||||||
url = get_base_url()
|
url = get_base_url()
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n=== 测试错误处理 ===")
|
print("\n=== 测试错误处理 ===")
|
||||||
|
|
||||||
# 测试空消息列表
|
# 测试空消息列表
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试空消息列表 ---")
|
print("\n--- 测试空消息列表 ---")
|
||||||
@@ -469,7 +456,7 @@ def test_error_handling():
|
|||||||
response = make_request(url, data)
|
response = make_request(url, data)
|
||||||
print(f"状态码: {response.status_code}")
|
print(f"状态码: {response.status_code}")
|
||||||
print_json_response(response.json(), "错误信息")
|
print_json_response(response.json(), "错误信息")
|
||||||
|
|
||||||
# 测试无效角色字段
|
# 测试无效角色字段
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n--- 测试无效角色字段 ---")
|
print("\n--- 测试无效角色字段 ---")
|
||||||
@@ -490,7 +477,7 @@ def test_error_handling():
|
|||||||
|
|
||||||
def get_test_cases() -> Dict[str, Callable]:
|
def get_test_cases() -> Dict[str, Callable]:
|
||||||
"""获取所有可用的测试用例
|
"""获取所有可用的测试用例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
测试名称到测试函数的映射字典
|
测试名称到测试函数的映射字典
|
||||||
"""
|
"""
|
||||||
@@ -564,21 +551,21 @@ def parse_args() -> argparse.Namespace:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# 设置输出模式
|
# 设置输出模式
|
||||||
OutputControl.set_verbose(not args.quiet)
|
OutputControl.set_verbose(not args.quiet)
|
||||||
|
|
||||||
# 如果指定了查询内容,更新配置
|
# 如果指定了查询内容,更新配置
|
||||||
if args.ask:
|
if args.ask:
|
||||||
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
||||||
|
|
||||||
# 如果指定了创建配置文件
|
# 如果指定了创建配置文件
|
||||||
if args.init_config:
|
if args.init_config:
|
||||||
create_default_config()
|
create_default_config()
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
test_cases = get_test_cases()
|
test_cases = get_test_cases()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "all" in args.tests:
|
if "all" in args.tests:
|
||||||
# 运行所有测试
|
# 运行所有测试
|
||||||
@@ -586,11 +573,11 @@ if __name__ == "__main__":
|
|||||||
print("\n【基本功能测试】")
|
print("\n【基本功能测试】")
|
||||||
run_test(test_non_stream_chat, "非流式调用测试")
|
run_test(test_non_stream_chat, "非流式调用测试")
|
||||||
run_test(test_stream_chat, "流式调用测试")
|
run_test(test_stream_chat, "流式调用测试")
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n【查询模式测试】")
|
print("\n【查询模式测试】")
|
||||||
run_test(test_query_modes, "查询模式测试")
|
run_test(test_query_modes, "查询模式测试")
|
||||||
|
|
||||||
if OutputControl.is_verbose():
|
if OutputControl.is_verbose():
|
||||||
print("\n【错误处理测试】")
|
print("\n【错误处理测试】")
|
||||||
run_test(test_error_handling, "错误处理测试")
|
run_test(test_error_handling, "错误处理测试")
|
||||||
|
Reference in New Issue
Block a user