Add generate API tests and enhance chat API tests

- Add non-streaming generate API test
- Add streaming generate API test
- Add generate error handling tests
- Add generate performance stats test
- Add generate concurrent request test
Author: yangdx
Date:   2025-01-24 19:09:31 +08:00
Parent: 2c8885792c
Commit: c26d799bb6


@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
"max_retries": 3,
"retry_delay": 1,
},
"test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
"test_cases": {
"basic": {"query": "唐僧有几个徒弟"},
"generate": {"query": "电视剧西游记导演是谁"}
},
}
@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
 CONFIG = load_config()
-def get_base_url() -> str:
-    """Return the base URL"""
+def get_base_url(endpoint: str = "chat") -> str:
+    """Return the base URL for the specified endpoint
+    Args:
+        endpoint: API endpoint name (chat or generate)
+    Returns:
+        Complete URL for the endpoint
+    """
     server = CONFIG["server"]
-    return f"http://{server['host']}:{server['port']}/api/chat"
+    return f"http://{server['host']}:{server['port']}/api/{endpoint}"
-def create_request_data(
+def create_chat_request_data(
     content: str, stream: bool = False, model: str = None
 ) -> Dict[str, Any]:
-    """Create basic request data
+    """Create chat request data
     Args:
         content: User message content
         stream: Whether to use streaming response
         model: Model name
     Returns:
-        Dictionary containing complete request data
+        Dictionary containing complete chat request data
     """
     return {
         "model": model or CONFIG["server"]["model"],
@@ -197,6 +205,34 @@ def create_request_data(
"stream": stream,
}
def create_generate_request_data(
prompt: str,
system: str = None,
stream: bool = False,
model: str = None,
options: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Create generate request data
Args:
prompt: Generation prompt
system: System prompt
stream: Whether to use streaming response
model: Model name
options: Additional options
Returns:
Dictionary containing complete generate request data
"""
data = {
"model": model or CONFIG["server"]["model"],
"prompt": prompt,
"stream": stream
}
if system:
data["system"] = system
if options:
data["options"] = options
return data
# Global test statistics
STATS = TestStats()
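
For illustration, a minimal sketch of the payload create_generate_request_data builds. The system string and the option key below are placeholders for this example, not values taken from the test configuration:

    payload = create_generate_request_data(
        "电视剧西游记导演是谁",        # "Who directed the TV series Journey to the West?"
        system="你是一个助手",         # hypothetical system prompt ("You are an assistant")
        stream=False,
        options={"temperature": 0.7},  # hypothetical Ollama-style option
    )
    # payload == {
    #     "model": "<model name from CONFIG>",
    #     "prompt": "电视剧西游记导演是谁",
    #     "stream": False,
    #     "system": "你是一个助手",
    #     "options": {"temperature": 0.7},
    # }
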
@@ -219,10 +255,10 @@ def run_test(func: Callable, name: str) -> None:
         raise
-def test_non_stream_chat():
+def test_non_stream_chat() -> None:
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
+    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
     # Send request
     response = make_request(url, data)
@@ -239,7 +275,7 @@ def test_non_stream_chat():
     )
-def test_stream_chat():
+def test_stream_chat() -> None:
     """Test streaming call to /api/chat endpoint
     Use JSON Lines format to process streaming responses; each line is a complete JSON object.
@@ -258,7 +294,7 @@ def test_stream_chat():
     The last message will contain performance statistics, with done set to true.
     """
     url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
+    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
     # Send request and get streaming response
     response = make_request(url, data, stream=True)
@@ -295,7 +331,7 @@ def test_stream_chat():
     print()
-def test_query_modes():
+def test_query_modes() -> None:
     """Test different query mode prefixes
     Supported query modes:
@@ -313,7 +349,7 @@ def test_query_modes():
     for mode in modes:
         if OutputControl.is_verbose():
             print(f"\n=== Testing /{mode} mode ===")
-        data = create_request_data(
+        data = create_chat_request_data(
             f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
         )
@@ -354,7 +390,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
     return error_data.get(error_type, error_data["empty_messages"])
-def test_stream_error_handling():
+def test_stream_error_handling() -> None:
     """Test error handling for streaming responses
     Test scenarios:
@@ -400,7 +436,7 @@ def test_stream_error_handling():
         response.close()
-def test_error_handling():
+def test_error_handling() -> None:
     """Test error handling for non-streaming responses
     Test scenarios:
@@ -447,6 +483,228 @@ def test_error_handling():
     print_json_response(response.json(), "Error message")
+def test_non_stream_generate() -> None:
+    """Test non-streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        stream=False
+    )
+    # Send request
+    response = make_request(url, data)
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Non-streaming generate response ===")
+    response_json = response.json()
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"]
+        },
+        "Response content"
+    )
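
Assuming an Ollama-compatible /api/generate contract, the three fields this test extracts from the JSON body would look roughly like this (values illustrative, not captured output):

    sample_response = {
        "model": "qwen2",                        # assumed: whatever model the server reports
        "response": "...generated answer text...",
        "done": True,                            # True for a completed non-streaming call
    }
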
+def test_stream_generate() -> None:
+    """Test streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        stream=True
+    )
+    # Send request and get streaming response
+    response = make_request(url, data, stream=True)
+    if OutputControl.is_verbose():
+        print("\n=== Streaming generate response ===")
+    output_buffer = []
+    try:
+        for line in response.iter_lines():
+            if line:  # Skip empty lines
+                try:
+                    # Decode and parse JSON
+                    data = json.loads(line.decode("utf-8"))
+                    if data.get("done", True):  # If it's the completion marker
+                        if "total_duration" in data:  # Final performance statistics message
+                            break
+                    else:  # Normal content message
+                        content = data.get("response", "")
+                        if content:  # Only collect non-empty content
+                            output_buffer.append(content)
+                            print(content, end="", flush=True)  # Print content in real time
+                except json.JSONDecodeError:
+                    print("Error decoding JSON from response line")
+    finally:
+        response.close()  # Ensure the response connection is closed
+    # Print a final newline
+    print()
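
The parsing loop above assumes Ollama-style NDJSON streaming: every line is a standalone JSON object, text fragments arrive in the "response" field while done is false, and the final line sets done to true and carries the timing fields. The two line shapes it distinguishes, with illustrative values:

    content_line = {"response": "片段", "done": False}    # fragment appended to output_buffer
    final_line = {
        "done": True,
        "total_duration": 5_589_157_167,  # presence of this key ends the loop
    }
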
+def test_generate_with_system() -> None:
+    """Test generate with a system prompt"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        system="你是一个知识渊博的助手",  # "You are a knowledgeable assistant"
+        stream=False
+    )
+    # Send request
+    response = make_request(url, data)
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Generate with system prompt response ===")
+    response_json = response.json()
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"]
+        },
+        "Response content"
+    )
+def test_generate_error_handling() -> None:
+    """Test error handling for the generate endpoint"""
+    url = get_base_url("generate")
+    # Test empty prompt
+    if OutputControl.is_verbose():
+        print("\n=== Testing empty prompt ===")
+    data = create_generate_request_data("", stream=False)
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+    # Test invalid options
+    if OutputControl.is_verbose():
+        print("\n=== Testing invalid options ===")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["basic"]["query"],
+        options={"invalid_option": "value"},
+        stream=False
+    )
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+    # Test a very long input
+    if OutputControl.is_verbose():
+        print("\n=== Testing very long input ===")
+    long_text = "测试" * 10000  # Build a very long input ("测试" = "test")
+    data = create_generate_request_data(long_text, stream=False)
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+def test_generate_performance_stats() -> None:
+    """Test the performance statistics in the generate response"""
+    url = get_base_url("generate")
+    # Test inputs of different lengths to verify token counting
+    inputs = [
+        "你好",  # Short Chinese ("Hello")
+        "Hello world",  # Short English
+        "这是一个较长的中文输入用来测试token数量的估算是否准确。",  # Medium Chinese
+        "This is a longer English input that will be used to test the accuracy of token count estimation."  # Medium English
+    ]
+    for test_input in inputs:
+        if OutputControl.is_verbose():
+            print(f"\n=== Testing performance stats with input: {test_input} ===")
+        data = create_generate_request_data(test_input, stream=False)
+        response = make_request(url, data)
+        response_json = response.json()
+        # Verify that the performance statistics exist and are reasonable
+        stats = {
+            "total_duration": response_json.get("total_duration"),
+            "prompt_eval_count": response_json.get("prompt_eval_count"),
+            "prompt_eval_duration": response_json.get("prompt_eval_duration"),
+            "eval_count": response_json.get("eval_count"),
+            "eval_duration": response_json.get("eval_duration")
+        }
+        print_json_response(stats, "Performance statistics")
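
If the server mirrors Ollama's conventions (an assumption, since this script targets an Ollama-compatible API), the *_duration fields are nanoseconds and the *_count fields are token counts, so a returned stats block might look like this (numbers illustrative):

    stats = {
        "total_duration": 5_589_157_167,     # ns for the whole request
        "prompt_eval_count": 26,             # prompt tokens
        "prompt_eval_duration": 325_953_000, # ns spent evaluating the prompt
        "eval_count": 290,                   # generated tokens
        "eval_duration": 4_709_213_000,      # ns spent generating
    }
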
+def test_generate_concurrent() -> None:
+    """Test concurrent generate requests"""
+    import asyncio
+    import aiohttp
+    from contextlib import asynccontextmanager
+    @asynccontextmanager
+    async def get_session():
+        async with aiohttp.ClientSession() as session:
+            yield session
+    async def make_request(session, prompt: str):
+        url = get_base_url("generate")
+        data = create_generate_request_data(prompt, stream=False)
+        try:
+            async with session.post(url, json=data) as response:
+                return await response.json()
+        except Exception as e:
+            return {"error": str(e)}
+    async def run_concurrent_requests():
+        prompts = [
+            "第一个问题",  # "Question one"
+            "第二个问题",  # "Question two"
+            "第三个问题",  # "Question three"
+            "第四个问题",  # "Question four"
+            "第五个问题"  # "Question five"
+        ]
+        async with get_session() as session:
+            tasks = [make_request(session, prompt) for prompt in prompts]
+            results = await asyncio.gather(*tasks)
+        return results
+    if OutputControl.is_verbose():
+        print("\n=== Testing concurrent generate requests ===")
+    # Run the concurrent requests
+    results = asyncio.run(run_concurrent_requests())
+    # Print the results
+    for i, result in enumerate(results, 1):
+        print(f"\nRequest {i} result:")
+        print_json_response(result)
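
Two details of this test worth noting: all five requests share a single aiohttp.ClientSession so connections are reused, and the nested make_request coroutine shadows the module-level make_request helper within this function. A thread-based sketch of the same idea, assuming the requests library (which the synchronous helpers appear to use), for contexts where an event loop is undesirable:

    from concurrent.futures import ThreadPoolExecutor

    import requests

    def fire(prompt: str) -> dict:
        # Reuses the module's URL and payload helpers; sketch only, untested.
        url = get_base_url("generate")
        data = create_generate_request_data(prompt, stream=False)
        return requests.post(url, json=data).json()

    prompts = ["第一个问题", "第二个问题", "第三个问题"]  # "Question one/two/three"
    with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
        results = list(pool.map(fire, prompts))
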
+def test_generate_query_modes() -> None:
+    """Test different query mode prefixes for the generate endpoint"""
+    url = get_base_url("generate")
+    modes = ["local", "global", "naive", "hybrid", "mix"]
+    for mode in modes:
+        if OutputControl.is_verbose():
+            print(f"\n=== Testing /{mode} mode for generate ===")
+        data = create_generate_request_data(
+            f"/{mode} {CONFIG['test_cases']['generate']['query']}",
+            stream=False
+        )
+        # Send request
+        response = make_request(url, data)
+        response_json = response.json()
+        # Print response content
+        print_json_response(
+            {
+                "model": response_json["model"],
+                "response": response_json["response"],
+                "done": response_json["done"]
+            }
+        )
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
@@ -458,6 +716,13 @@ def get_test_cases() -> Dict[str, Callable]:
"modes": test_query_modes,
"errors": test_error_handling,
"stream_errors": test_stream_error_handling,
"non_stream_generate": test_non_stream_generate,
"stream_generate": test_stream_generate,
"generate_with_system": test_generate_with_system,
"generate_modes": test_generate_query_modes,
"generate_errors": test_generate_error_handling,
"generate_stats": test_generate_performance_stats,
"generate_concurrent": test_generate_concurrent
}
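
Each registry key doubles as a test selector, so a subset of tests can be driven straight off the mapping. A hypothetical snippet mirroring the script's "Run specified tests" branch below:

    tests = get_test_cases()
    for name in ["stream_generate", "generate_errors"]:
        run_test(tests[name], name)
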
@@ -544,18 +809,22 @@ if __name__ == "__main__":
if "all" in args.tests:
# Run all tests
if OutputControl.is_verbose():
print("\nBasic Functionality Tests】")
run_test(test_non_stream_chat, "Non-streaming Call Test")
run_test(test_stream_chat, "Streaming Call Test")
print("\nChat API Tests】")
run_test(test_non_stream_chat, "Non-streaming Chat Test")
run_test(test_stream_chat, "Streaming Chat Test")
run_test(test_query_modes, "Chat Query Mode Test")
run_test(test_error_handling, "Chat Error Handling Test")
run_test(test_stream_error_handling, "Chat Streaming Error Test")
if OutputControl.is_verbose():
print("\nQuery Mode Tests】")
run_test(test_query_modes, "Query Mode Test")
if OutputControl.is_verbose():
print("\n【Error Handling Tests】")
run_test(test_error_handling, "Error Handling Test")
run_test(test_stream_error_handling, "Streaming Error Handling Test")
print("\nGenerate API Tests】")
run_test(test_non_stream_generate, "Non-streaming Generate Test")
run_test(test_stream_generate, "Streaming Generate Test")
run_test(test_generate_with_system, "Generate with System Prompt Test")
run_test(test_generate_query_modes, "Generate Query Mode Test")
run_test(test_generate_error_handling, "Generate Error Handling Test")
run_test(test_generate_performance_stats, "Generate Performance Stats Test")
run_test(test_generate_concurrent, "Generate Concurrent Test")
else:
# Run specified tests
for test_name in args.tests: