Add generate API tests and enhance chat API tests
- Add non-streaming generate API test
- Add streaming generate API test
- Add generate error handling tests
- Add generate performance stats test
- Add generate concurrent request test
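The new tests exercise an Ollama-style /api/generate endpoint alongside the existing /api/chat tests. As a hedged sketch, the request/response shape they drive looks like the following (host, port, and model name are illustrative placeholders, not values taken from this commit; the script reads the real values from its CONFIG):

```python
# Minimal sketch of the generate request the new tests send.
# Host, port, and model are placeholders (assumptions, not from this diff).
import requests

resp = requests.post(
    "http://localhost:9621/api/generate",
    json={"model": "test-model", "prompt": "电视剧西游记导演是谁", "stream": False},
    timeout=30,
)
body = resp.json()
print(body["response"], body["done"])  # the fields the tests read back
```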
@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
         "max_retries": 3,
         "retry_delay": 1,
     },
-    "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
+    "test_cases": {
+        "basic": {"query": "唐僧有几个徒弟"},
+        "generate": {"query": "电视剧西游记导演是谁"}
+    },
 }
 
 
@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
 CONFIG = load_config()
 
 
-def get_base_url() -> str:
-    """Return the base URL"""
+def get_base_url(endpoint: str = "chat") -> str:
+    """Return the base URL for specified endpoint
+    Args:
+        endpoint: API endpoint name (chat or generate)
+    Returns:
+        Complete URL for the endpoint
+    """
     server = CONFIG["server"]
-    return f"http://{server['host']}:{server['port']}/api/chat"
+    return f"http://{server['host']}:{server['port']}/api/{endpoint}"
 
 
-def create_request_data(
+def create_chat_request_data(
     content: str, stream: bool = False, model: str = None
 ) -> Dict[str, Any]:
-    """Create basic request data
+    """Create chat request data
     Args:
         content: User message content
         stream: Whether to use streaming response
         model: Model name
     Returns:
-        Dictionary containing complete request data
+        Dictionary containing complete chat request data
     """
     return {
         "model": model or CONFIG["server"]["model"],
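A quick usage sketch of the two renamed helpers above (grounded in this hunk; note that the middle of the chat payload is elided by the diff):

```python
# Usage sketch; values come from the loaded CONFIG.
url = get_base_url("generate")  # -> "http://<host>:<port>/api/generate"
chat_payload = create_chat_request_data("唐僧有几个徒弟", stream=False)
# chat_payload carries the "model" and "stream" keys shown above; the
# message fields between them are elided from this diff hunk.
```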
@@ -197,6 +205,34 @@ def create_request_data(
         "stream": stream,
     }
 
+def create_generate_request_data(
+    prompt: str,
+    system: str = None,
+    stream: bool = False,
+    model: str = None,
+    options: Dict[str, Any] = None
+) -> Dict[str, Any]:
+    """Create generate request data
+    Args:
+        prompt: Generation prompt
+        system: System prompt
+        stream: Whether to use streaming response
+        model: Model name
+        options: Additional options
+    Returns:
+        Dictionary containing complete generate request data
+    """
+    data = {
+        "model": model or CONFIG["server"]["model"],
+        "prompt": prompt,
+        "stream": stream
+    }
+    if system:
+        data["system"] = system
+    if options:
+        data["options"] = options
+    return data
+
 
 # Global test statistics
 STATS = TestStats()
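Following directly from the helper just added, a usage sketch showing how the optional fields are attached:

```python
# Built by create_generate_request_data as defined above.
payload = create_generate_request_data(
    "电视剧西游记导演是谁",
    system="你是一个知识渊博的助手",
    stream=False,
)
# -> {"model": CONFIG["server"]["model"],
#     "prompt": "电视剧西游记导演是谁",
#     "stream": False,
#     "system": "你是一个知识渊博的助手"}
```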
@@ -219,10 +255,10 @@ def run_test(func: Callable, name: str) -> None:
         raise
 
 
-def test_non_stream_chat():
+def test_non_stream_chat() -> None:
     """Test non-streaming call to /api/chat endpoint"""
     url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
+    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
 
     # Send request
     response = make_request(url, data)
@@ -239,7 +275,7 @@ def test_non_stream_chat():
     )
 
 
-def test_stream_chat():
+def test_stream_chat() -> None:
     """Test streaming call to /api/chat endpoint
 
     Use JSON Lines format to process streaming responses, each line is a complete JSON object.
@@ -258,7 +294,7 @@ def test_stream_chat():
     The last message will contain performance statistics, with done set to true.
     """
     url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
+    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
 
     # Send request and get streaming response
     response = make_request(url, data, stream=True)
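For reference, the JSON Lines framing this docstring describes means each streamed line parses independently; a small illustration (frame contents invented for illustration, though the field names match what the streaming parsers in this file read):

```python
import json

# One JSON object per line; the final frame sets done=true and carries stats.
frame = json.loads('{"response": "partial text", "done": false}')
assert frame["done"] is False and frame["response"] == "partial text"
final = json.loads('{"done": true, "total_duration": 123456789, "eval_count": 42}')
assert final["done"] is True and "total_duration" in final
```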
@@ -295,7 +331,7 @@ def test_stream_chat():
     print()
 
 
-def test_query_modes():
+def test_query_modes() -> None:
     """Test different query mode prefixes
 
     Supported query modes:
@@ -313,7 +349,7 @@ def test_query_modes():
     for mode in modes:
         if OutputControl.is_verbose():
             print(f"\n=== Testing /{mode} mode ===")
-        data = create_request_data(
+        data = create_chat_request_data(
             f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
         )
 
@@ -354,7 +390,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
     return error_data.get(error_type, error_data["empty_messages"])
 
 
-def test_stream_error_handling():
+def test_stream_error_handling() -> None:
     """Test error handling for streaming responses
 
     Test scenarios:
@@ -400,7 +436,7 @@ def test_stream_error_handling():
     response.close()
 
 
-def test_error_handling():
+def test_error_handling() -> None:
     """Test error handling for non-streaming responses
 
     Test scenarios:
@@ -447,6 +483,228 @@ def test_error_handling():
     print_json_response(response.json(), "Error message")
 
 
+def test_non_stream_generate() -> None:
+    """Test non-streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        stream=False
+    )
+
+    # Send request
+    response = make_request(url, data)
+
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Non-streaming generate response ===")
+    response_json = response.json()
+
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"]
+        },
+        "Response content"
+    )
+
+def test_stream_generate() -> None:
+    """Test streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        stream=True
+    )
+
+    # Send request and get streaming response
+    response = make_request(url, data, stream=True)
+
+    if OutputControl.is_verbose():
+        print("\n=== Streaming generate response ===")
+    output_buffer = []
+    try:
+        for line in response.iter_lines():
+            if line:  # Skip empty lines
+                try:
+                    # Decode and parse JSON
+                    data = json.loads(line.decode("utf-8"))
+                    if data.get("done", True):  # If it's the completion marker
+                        if "total_duration" in data:  # Final performance statistics message
+                            break
+                    else:  # Normal content message
+                        content = data.get("response", "")
+                        if content:  # Only collect non-empty content
+                            output_buffer.append(content)
+                            print(content, end="", flush=True)  # Print content in real-time
+                except json.JSONDecodeError:
+                    print("Error decoding JSON from response line")
+    finally:
+        response.close()  # Ensure the response connection is closed
+
+    # Print a newline
+    print()
+
+def test_generate_with_system() -> None:
+    """Test generate with system prompt"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        system="你是一个知识渊博的助手",
+        stream=False
+    )
+
+    # Send request
+    response = make_request(url, data)
+
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Generate with system prompt response ===")
+    response_json = response.json()
+
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"]
+        },
+        "Response content"
+    )
+
+def test_generate_error_handling() -> None:
+    """Test error handling for generate endpoint"""
+    url = get_base_url("generate")
+
+    # Test empty prompt
+    if OutputControl.is_verbose():
+        print("\n=== Testing empty prompt ===")
+    data = create_generate_request_data("", stream=False)
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+    # Test invalid options
+    if OutputControl.is_verbose():
+        print("\n=== Testing invalid options ===")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["basic"]["query"],
+        options={"invalid_option": "value"},
+        stream=False
+    )
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+    # Test very long input
+    if OutputControl.is_verbose():
+        print("\n=== Testing very long input ===")
+    long_text = "测试" * 10000  # Create a very long input
+    data = create_generate_request_data(long_text, stream=False)
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+def test_generate_performance_stats() -> None:
+    """Test performance statistics in generate response"""
+    url = get_base_url("generate")
+
+    # Test with different length inputs to verify token counting
+    inputs = [
+        "你好",  # Short Chinese
+        "Hello world",  # Short English
+        "这是一个较长的中文输入,用来测试token数量的估算是否准确。",  # Medium Chinese
+        "This is a longer English input that will be used to test the accuracy of token count estimation."  # Medium English
+    ]
+
+    for test_input in inputs:
+        if OutputControl.is_verbose():
+            print(f"\n=== Testing performance stats with input: {test_input} ===")
+        data = create_generate_request_data(test_input, stream=False)
+        response = make_request(url, data)
+        response_json = response.json()
+
+        # Verify performance statistics exist and are reasonable
+        stats = {
+            "total_duration": response_json.get("total_duration"),
+            "prompt_eval_count": response_json.get("prompt_eval_count"),
+            "prompt_eval_duration": response_json.get("prompt_eval_duration"),
+            "eval_count": response_json.get("eval_count"),
+            "eval_duration": response_json.get("eval_duration")
+        }
+        print_json_response(stats, "Performance statistics")
+
+def test_generate_concurrent() -> None:
+    """Test concurrent generate requests"""
+    import asyncio
+    import aiohttp
+    from contextlib import asynccontextmanager
+
+    @asynccontextmanager
+    async def get_session():
+        async with aiohttp.ClientSession() as session:
+            yield session
+
+    async def make_request(session, prompt: str):
+        url = get_base_url("generate")
+        data = create_generate_request_data(prompt, stream=False)
+        try:
+            async with session.post(url, json=data) as response:
+                return await response.json()
+        except Exception as e:
+            return {"error": str(e)}
+
+    async def run_concurrent_requests():
+        prompts = [
+            "第一个问题",
+            "第二个问题",
+            "第三个问题",
+            "第四个问题",
+            "第五个问题"
+        ]
+
+        async with get_session() as session:
+            tasks = [make_request(session, prompt) for prompt in prompts]
+            results = await asyncio.gather(*tasks)
+            return results
+
+    if OutputControl.is_verbose():
+        print("\n=== Testing concurrent generate requests ===")
+
+    # Run concurrent requests
+    results = asyncio.run(run_concurrent_requests())
+
+    # Print results
+    for i, result in enumerate(results, 1):
+        print(f"\nRequest {i} result:")
+        print_json_response(result)
+
+def test_generate_query_modes() -> None:
+    """Test different query mode prefixes for generate endpoint"""
+    url = get_base_url("generate")
+    modes = ["local", "global", "naive", "hybrid", "mix"]
+
+    for mode in modes:
+        if OutputControl.is_verbose():
+            print(f"\n=== Testing /{mode} mode for generate ===")
+        data = create_generate_request_data(
+            f"/{mode} {CONFIG['test_cases']['generate']['query']}",
+            stream=False
+        )
+
+        # Send request
+        response = make_request(url, data)
+        response_json = response.json()
+
+        # Print response content
+        print_json_response(
+            {
+                "model": response_json["model"],
+                "response": response_json["response"],
+                "done": response_json["done"]
+            }
+        )
+
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
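A note on the statistics collected by test_generate_performance_stats: if the server follows Ollama's convention of reporting durations in nanoseconds (an assumption, not something this diff states), throughput can be derived from the collected fields:

```python
# Hedged sketch: tokens/second from eval_count and eval_duration,
# assuming Ollama-style nanosecond durations (an assumption).
def tokens_per_second(eval_count: int, eval_duration_ns: int) -> float:
    return eval_count / (eval_duration_ns / 1e9) if eval_duration_ns else 0.0

print(tokens_per_second(42, 1_500_000_000))  # -> 28.0
```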
@@ -458,6 +716,13 @@ def get_test_cases() -> Dict[str, Callable]:
         "modes": test_query_modes,
         "errors": test_error_handling,
         "stream_errors": test_stream_error_handling,
+        "non_stream_generate": test_non_stream_generate,
+        "stream_generate": test_stream_generate,
+        "generate_with_system": test_generate_with_system,
+        "generate_modes": test_generate_query_modes,
+        "generate_errors": test_generate_error_handling,
+        "generate_stats": test_generate_performance_stats,
+        "generate_concurrent": test_generate_concurrent
     }
 
 
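With the generate tests registered above, a single case can be dispatched by key, mirroring what the __main__ block below does for names passed in args.tests:

```python
# Dispatch one registered test by key, as the __main__ block does.
tests = get_test_cases()
run_test(tests["non_stream_generate"], "Non-streaming Generate Test")
```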
@@ -544,18 +809,22 @@ if __name__ == "__main__":
     if "all" in args.tests:
         # Run all tests
         if OutputControl.is_verbose():
-            print("\n【Basic Functionality Tests】")
-        run_test(test_non_stream_chat, "Non-streaming Call Test")
-        run_test(test_stream_chat, "Streaming Call Test")
+            print("\n【Chat API Tests】")
+        run_test(test_non_stream_chat, "Non-streaming Chat Test")
+        run_test(test_stream_chat, "Streaming Chat Test")
+        run_test(test_query_modes, "Chat Query Mode Test")
+        run_test(test_error_handling, "Chat Error Handling Test")
+        run_test(test_stream_error_handling, "Chat Streaming Error Test")
 
         if OutputControl.is_verbose():
-            print("\n【Query Mode Tests】")
-        run_test(test_query_modes, "Query Mode Test")
-
-        if OutputControl.is_verbose():
-            print("\n【Error Handling Tests】")
-        run_test(test_error_handling, "Error Handling Test")
-        run_test(test_stream_error_handling, "Streaming Error Handling Test")
+            print("\n【Generate API Tests】")
+        run_test(test_non_stream_generate, "Non-streaming Generate Test")
+        run_test(test_stream_generate, "Streaming Generate Test")
+        run_test(test_generate_with_system, "Generate with System Prompt Test")
+        run_test(test_generate_query_modes, "Generate Query Mode Test")
+        run_test(test_generate_error_handling, "Generate Error Handling Test")
+        run_test(test_generate_performance_stats, "Generate Performance Stats Test")
+        run_test(test_generate_concurrent, "Generate Concurrent Test")
     else:
        # Run specified tests
        for test_name in args.tests: