diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 8cb633ba..d17d50fc 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -17,6 +17,19 @@ from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path +from enum import Enum, auto + +class ErrorCode(Enum): + """Error codes for MCP errors""" + InvalidRequest = auto() + InternalError = auto() + +class McpError(Exception): + """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): + self.code = code + self.message = message + super().__init__(message) DEFAULT_CONFIG = { "server": { @@ -641,35 +654,76 @@ def test_generate_concurrent() -> None: async with aiohttp.ClientSession() as session: yield session - async def make_request(session, prompt: str): + async def make_request(session, prompt: str, request_id: int): url = get_base_url("generate") data = create_generate_request_data(prompt, stream=False) try: async with session.post(url, json=data) as response: if response.status != 200: - response.raise_for_status() - return await response.json() + error_msg = f"Request {request_id} failed with status {response.status}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + result = await response.json() + if "error" in result: + error_msg = f"Request {request_id} returned error: {result['error']}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + return result except Exception as e: - return {"error": str(e)} + error_msg = f"Request {request_id} failed: {str(e)}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt) for prompt in prompts] - results = await asyncio.gather(*tasks) + tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 收集成功和失败的结果 + success_results = [] + error_messages = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + error_messages.append(f"Request {i+1} failed: {str(result)}") + else: + success_results.append((i+1, result)) + + # 如果有任何错误,在打印完所有结果后抛出异常 + if error_messages: + # 先打印成功的结果 + for req_id, result in success_results: + if OutputControl.is_verbose(): + print(f"\nRequest {req_id} succeeded:") + print_json_response(result) + + # 打印所有错误信息 + error_summary = "\n".join(error_messages) + raise McpError( + ErrorCode.InternalError, + f"Some concurrent requests failed:\n{error_summary}" + ) + return results if OutputControl.is_verbose(): print("\n=== Testing concurrent generate requests ===") # Run concurrent requests - results = asyncio.run(run_concurrent_requests()) - - # Print results - for i, result in enumerate(results, 1): - print(f"\nRequest {i} result:") - print_json_response(result) + try: + results = asyncio.run(run_concurrent_requests()) + # 如果没有异常,打印所有成功的结果 + for i, result in enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + except McpError as e: + # 错误信息已经在之前打印过了,这里直接抛出 + raise def get_test_cases() -> Dict[str, Callable]: