Add error handling and improve logging for concurrent request testing

• Added McpError and ErrorCode classes
• Added detailed error collection logic
• Improved error reporting & formatting
• Added request ID tracking
• Enhanced test results visibility
This commit is contained in:
yangdx
2025-02-06 03:17:27 +08:00
committed by ultrageopro
parent 4bdb8e30de
commit e4476ef0af

View File

@@ -17,6 +17,19 @@ from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from enum import Enum, auto
class ErrorCode(Enum):
"""Error codes for MCP errors"""
InvalidRequest = auto()
InternalError = auto()
class McpError(Exception):
"""Base exception class for MCP errors"""
def __init__(self, code: ErrorCode, message: str):
self.code = code
self.message = message
super().__init__(message)
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"server": { "server": {
@@ -634,35 +647,76 @@ def test_generate_concurrent() -> None:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
yield session yield session
async def make_request(session, prompt: str): async def make_request(session, prompt: str, request_id: int):
url = get_base_url("generate") url = get_base_url("generate")
data = create_generate_request_data(prompt, stream=False) data = create_generate_request_data(prompt, stream=False)
try: try:
async with session.post(url, json=data) as response: async with session.post(url, json=data) as response:
if response.status != 200: if response.status != 200:
response.raise_for_status() error_msg = f"Request {request_id} failed with status {response.status}"
return await response.json() if OutputControl.is_verbose():
print(f"\n{error_msg}")
raise McpError(ErrorCode.InternalError, error_msg)
result = await response.json()
if "error" in result:
error_msg = f"Request {request_id} returned error: {result['error']}"
if OutputControl.is_verbose():
print(f"\n{error_msg}")
raise McpError(ErrorCode.InternalError, error_msg)
return result
except Exception as e: except Exception as e:
return {"error": str(e)} error_msg = f"Request {request_id} failed: {str(e)}"
if OutputControl.is_verbose():
print(f"\n{error_msg}")
raise McpError(ErrorCode.InternalError, error_msg)
async def run_concurrent_requests(): async def run_concurrent_requests():
prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
async with get_session() as session: async with get_session() as session:
tasks = [make_request(session, prompt) for prompt in prompts] tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks, return_exceptions=True)
# 收集成功和失败的结果
success_results = []
error_messages = []
for i, result in enumerate(results):
if isinstance(result, Exception):
error_messages.append(f"Request {i+1} failed: {str(result)}")
else:
success_results.append((i+1, result))
# 如果有任何错误,在打印完所有结果后抛出异常
if error_messages:
# 先打印成功的结果
for req_id, result in success_results:
if OutputControl.is_verbose():
print(f"\nRequest {req_id} succeeded:")
print_json_response(result)
# 打印所有错误信息
error_summary = "\n".join(error_messages)
raise McpError(
ErrorCode.InternalError,
f"Some concurrent requests failed:\n{error_summary}"
)
return results return results
if OutputControl.is_verbose(): if OutputControl.is_verbose():
print("\n=== Testing concurrent generate requests ===") print("\n=== Testing concurrent generate requests ===")
# Run concurrent requests # Run concurrent requests
try:
results = asyncio.run(run_concurrent_requests()) results = asyncio.run(run_concurrent_requests())
# 如果没有异常,打印所有成功的结果
# Print results
for i, result in enumerate(results, 1): for i, result in enumerate(results, 1):
print(f"\nRequest {i} result:") print(f"\nRequest {i} result:")
print_json_response(result) print_json_response(result)
except McpError as e:
# 错误信息已经在之前打印过了,这里直接抛出
raise
def get_test_cases() -> Dict[str, Callable]: def get_test_cases() -> Dict[str, Callable]: