diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index d17d50fc..08ecad67 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -19,18 +19,23 @@ from datetime import datetime from pathlib import Path from enum import Enum, auto + class ErrorCode(Enum): """Error codes for MCP errors""" + InvalidRequest = auto() InternalError = auto() + class McpError(Exception): """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): self.code = code self.message = message super().__init__(message) + DEFAULT_CONFIG = { "server": { "host": "localhost", @@ -660,13 +665,17 @@ def test_generate_concurrent() -> None: try: async with session.post(url, json=data) as response: if response.status != 200: - error_msg = f"Request {request_id} failed with status {response.status}" + error_msg = ( + f"Request {request_id} failed with status {response.status}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) result = await response.json() if "error" in result: - error_msg = f"Request {request_id} returned error: {result['error']}" + error_msg = ( + f"Request {request_id} returned error: {result['error']}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) @@ -679,21 +688,23 @@ def test_generate_concurrent() -> None: async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + tasks = [ + make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts) + ] results = await asyncio.gather(*tasks, return_exceptions=True) - + # 收集成功和失败的结果 success_results = [] error_messages = [] - + for i, result in enumerate(results): if isinstance(result, Exception): error_messages.append(f"Request {i+1} failed: {str(result)}") else: - success_results.append((i+1, result)) - + success_results.append((i + 1, result)) + # 如果有任何错误,在打印完所有结果后抛出异常 if error_messages: # 先打印成功的结果 @@ -701,14 +712,14 @@ def test_generate_concurrent() -> None: if OutputControl.is_verbose(): print(f"\nRequest {req_id} succeeded:") print_json_response(result) - + # 打印所有错误信息 error_summary = "\n".join(error_messages) raise McpError( ErrorCode.InternalError, - f"Some concurrent requests failed:\n{error_summary}" + f"Some concurrent requests failed:\n{error_summary}", ) - + return results if OutputControl.is_verbose(): @@ -721,7 +732,7 @@ def test_generate_concurrent() -> None: for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) - except McpError as e: + except McpError: # 错误信息已经在之前打印过了,这里直接抛出 raise