pre-commit run --all-files
This commit is contained in:
@@ -18,8 +18,10 @@ from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class OutputControl:
|
||||
"""Output control class, manages the verbosity of test output"""
|
||||
|
||||
_verbose: bool = False
|
||||
|
||||
@classmethod
|
||||
@@ -30,9 +32,11 @@ class OutputControl:
|
||||
def is_verbose(cls) -> bool:
|
||||
return cls._verbose
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""Test result data class"""
|
||||
|
||||
name: str
|
||||
success: bool
|
||||
duration: float
|
||||
@@ -43,8 +47,10 @@ class TestResult:
|
||||
if not self.timestamp:
|
||||
self.timestamp = datetime.now().isoformat()
|
||||
|
||||
|
||||
class TestStats:
|
||||
"""Test statistics"""
|
||||
|
||||
def __init__(self):
|
||||
self.results: List[TestResult] = []
|
||||
self.start_time = datetime.now()
|
||||
@@ -65,8 +71,8 @@ class TestStats:
|
||||
"total": len(self.results),
|
||||
"passed": sum(1 for r in self.results if r.success),
|
||||
"failed": sum(1 for r in self.results if not r.success),
|
||||
"total_duration": sum(r.duration for r in self.results)
|
||||
}
|
||||
"total_duration": sum(r.duration for r in self.results),
|
||||
},
|
||||
}
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
@@ -92,6 +98,7 @@ class TestStats:
|
||||
if not result.success:
|
||||
print(f"- {result.name}: {result.error}")
|
||||
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"server": {
|
||||
"host": "localhost",
|
||||
@@ -99,16 +106,15 @@ DEFAULT_CONFIG = {
|
||||
"model": "lightrag:latest",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"retry_delay": 1
|
||||
"retry_delay": 1,
|
||||
},
|
||||
"test_cases": {
|
||||
"basic": {
|
||||
"query": "唐僧有几个徒弟"
|
||||
}
|
||||
}
|
||||
"test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
|
||||
}
|
||||
|
||||
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
||||
|
||||
def make_request(
|
||||
url: str, data: Dict[str, Any], stream: bool = False
|
||||
) -> requests.Response:
|
||||
"""Send an HTTP request with retry mechanism
|
||||
Args:
|
||||
url: Request URL
|
||||
@@ -127,12 +133,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json=data,
|
||||
stream=stream,
|
||||
timeout=timeout
|
||||
)
|
||||
response = requests.post(url, json=data, stream=stream, timeout=timeout)
|
||||
return response
|
||||
except requests.exceptions.RequestException as e:
|
||||
if attempt == max_retries - 1: # Last retry
|
||||
@@ -140,6 +141,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
||||
print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
|
||||
time.sleep(retry_delay)
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
"""Load configuration file
|
||||
|
||||
@@ -154,6 +156,7 @@ def load_config() -> Dict[str, Any]:
|
||||
return json.load(f)
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
|
||||
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
||||
"""Format and print JSON response data
|
||||
Args:
|
||||
@@ -166,18 +169,19 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
|
||||
print(f"\n=== {title} ===")
|
||||
print(json.dumps(data, ensure_ascii=False, indent=indent))
|
||||
|
||||
|
||||
# Global configuration
|
||||
CONFIG = load_config()
|
||||
|
||||
|
||||
def get_base_url() -> str:
|
||||
"""Return the base URL"""
|
||||
server = CONFIG["server"]
|
||||
return f"http://{server['host']}:{server['port']}/api/chat"
|
||||
|
||||
|
||||
def create_request_data(
|
||||
content: str,
|
||||
stream: bool = False,
|
||||
model: str = None
|
||||
content: str, stream: bool = False, model: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create basic request data
|
||||
Args:
|
||||
@@ -189,18 +193,15 @@ def create_request_data(
|
||||
"""
|
||||
return {
|
||||
"model": model or CONFIG["server"]["model"],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
],
|
||||
"stream": stream
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
|
||||
# Global test statistics
|
||||
STATS = TestStats()
|
||||
|
||||
|
||||
def run_test(func: Callable, name: str) -> None:
|
||||
"""Run a test and record the results
|
||||
Args:
|
||||
@@ -217,13 +218,11 @@ def run_test(func: Callable, name: str) -> None:
|
||||
STATS.add_result(TestResult(name, False, duration, str(e)))
|
||||
raise
|
||||
|
||||
|
||||
def test_non_stream_chat():
|
||||
"""Test non-streaming call to /api/chat endpoint"""
|
||||
url = get_base_url()
|
||||
data = create_request_data(
|
||||
CONFIG["test_cases"]["basic"]["query"],
|
||||
stream=False
|
||||
)
|
||||
data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
|
||||
|
||||
# Send request
|
||||
response = make_request(url, data)
|
||||
@@ -234,10 +233,12 @@ def test_non_stream_chat():
|
||||
response_json = response.json()
|
||||
|
||||
# Print response content
|
||||
print_json_response({
|
||||
"model": response_json["model"],
|
||||
"message": response_json["message"]
|
||||
}, "Response content")
|
||||
print_json_response(
|
||||
{"model": response_json["model"], "message": response_json["message"]},
|
||||
"Response content",
|
||||
)
|
||||
|
||||
|
||||
def test_stream_chat():
|
||||
"""Test streaming call to /api/chat endpoint
|
||||
|
||||
@@ -257,10 +258,7 @@ def test_stream_chat():
|
||||
The last message will contain performance statistics, with done set to true.
|
||||
"""
|
||||
url = get_base_url()
|
||||
data = create_request_data(
|
||||
CONFIG["test_cases"]["basic"]["query"],
|
||||
stream=True
|
||||
)
|
||||
data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
|
||||
|
||||
# Send request and get streaming response
|
||||
response = make_request(url, data, stream=True)
|
||||
@@ -273,9 +271,11 @@ def test_stream_chat():
|
||||
if line: # Skip empty lines
|
||||
try:
|
||||
# Decode and parse JSON
|
||||
data = json.loads(line.decode('utf-8'))
|
||||
data = json.loads(line.decode("utf-8"))
|
||||
if data.get("done", True): # If it's the completion marker
|
||||
if "total_duration" in data: # Final performance statistics message
|
||||
if (
|
||||
"total_duration" in data
|
||||
): # Final performance statistics message
|
||||
# print_json_response(data, "Performance statistics")
|
||||
break
|
||||
else: # Normal content message
|
||||
@@ -283,7 +283,9 @@ def test_stream_chat():
|
||||
content = message.get("content", "")
|
||||
if content: # Only collect non-empty content
|
||||
output_buffer.append(content)
|
||||
print(content, end="", flush=True) # Print content in real-time
|
||||
print(
|
||||
content, end="", flush=True
|
||||
) # Print content in real-time
|
||||
except json.JSONDecodeError:
|
||||
print("Error decoding JSON from response line")
|
||||
finally:
|
||||
@@ -292,6 +294,7 @@ def test_stream_chat():
|
||||
# Print a newline
|
||||
print()
|
||||
|
||||
|
||||
def test_query_modes():
|
||||
"""Test different query mode prefixes
|
||||
|
||||
@@ -311,8 +314,7 @@ def test_query_modes():
|
||||
if OutputControl.is_verbose():
|
||||
print(f"\n=== Testing /{mode} mode ===")
|
||||
data = create_request_data(
|
||||
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
||||
stream=False
|
||||
f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
|
||||
)
|
||||
|
||||
# Send request
|
||||
@@ -320,10 +322,10 @@ def test_query_modes():
|
||||
response_json = response.json()
|
||||
|
||||
# Print response content
|
||||
print_json_response({
|
||||
"model": response_json["model"],
|
||||
"message": response_json["message"]
|
||||
})
|
||||
print_json_response(
|
||||
{"model": response_json["model"], "message": response_json["message"]}
|
||||
)
|
||||
|
||||
|
||||
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||
"""Create request data for error testing
|
||||
@@ -337,33 +339,21 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
||||
Request dictionary containing error data
|
||||
"""
|
||||
error_data = {
|
||||
"empty_messages": {
|
||||
"model": "lightrag:latest",
|
||||
"messages": [],
|
||||
"stream": True
|
||||
},
|
||||
"empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
|
||||
"invalid_role": {
|
||||
"model": "lightrag:latest",
|
||||
"messages": [
|
||||
{
|
||||
"invalid_role": "user",
|
||||
"content": "Test message"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
"messages": [{"invalid_role": "user", "content": "Test message"}],
|
||||
"stream": True,
|
||||
},
|
||||
"missing_content": {
|
||||
"model": "lightrag:latest",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"stream": True
|
||||
}
|
||||
"messages": [{"role": "user"}],
|
||||
"stream": True,
|
||||
},
|
||||
}
|
||||
return error_data.get(error_type, error_data["empty_messages"])
|
||||
|
||||
|
||||
def test_stream_error_handling():
|
||||
"""Test error handling for streaming responses
|
||||
|
||||
@@ -409,6 +399,7 @@ def test_stream_error_handling():
|
||||
print_json_response(response.json(), "Error message")
|
||||
response.close()
|
||||
|
||||
|
||||
def test_error_handling():
|
||||
"""Test error handling for non-streaming responses
|
||||
|
||||
@@ -455,6 +446,7 @@ def test_error_handling():
|
||||
print(f"Status code: {response.status_code}")
|
||||
print_json_response(response.json(), "Error message")
|
||||
|
||||
|
||||
def get_test_cases() -> Dict[str, Callable]:
|
||||
"""Get all available test cases
|
||||
Returns:
|
||||
@@ -465,9 +457,10 @@ def get_test_cases() -> Dict[str, Callable]:
|
||||
"stream": test_stream_chat,
|
||||
"modes": test_query_modes,
|
||||
"errors": test_error_handling,
|
||||
"stream_errors": test_stream_error_handling
|
||||
"stream_errors": test_stream_error_handling,
|
||||
}
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default configuration file"""
|
||||
config_path = Path("config.json")
|
||||
@@ -476,6 +469,7 @@ def create_default_config():
|
||||
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
|
||||
print(f"Default configuration file created: {config_path}")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -496,38 +490,39 @@ Configuration file (config.json):
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--quiet",
|
||||
"-q",
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
help="Silent mode, only display test result summary"
|
||||
help="Silent mode, only display test result summary",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--ask",
|
||||
"-a",
|
||||
"--ask",
|
||||
type=str,
|
||||
help="Specify query content, which will override the query settings in the configuration file"
|
||||
help="Specify query content, which will override the query settings in the configuration file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-config",
|
||||
action="store_true",
|
||||
help="Create default configuration file"
|
||||
"--init-config", action="store_true", help="Create default configuration file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="",
|
||||
help="Test result output file path, default is not to output to a file"
|
||||
help="Test result output file path, default is not to output to a file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests",
|
||||
nargs="+",
|
||||
choices=list(get_test_cases().keys()) + ["all"],
|
||||
default=["all"],
|
||||
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests"
|
||||
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
|
Reference in New Issue
Block a user