diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 23b09f55..edf97993 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -205,14 +205,14 @@ class OllamaAPI: async def stream_generator(): try: first_chunk_time = None - last_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = start_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -241,22 +241,50 @@ class OllamaAPI: } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() - last_chunk_time = time.time_ns() + last_chunk_time = time.time_ns() - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + logging.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": f"\n\nError: {error_msg}", + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Send final message to close the stream + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + if first_chunk_time is None: + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -381,16 +409,16 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = None - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = start_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -474,45 +502,29 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - if last_chunk_time is not None: - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + if first_chunk_time is None: + first_chunk_time = start_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" except Exception as e: - error_msg = f"Error in stream_generator: {str(e)}" - logging.error(error_msg) - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": {"code": "STREAM_ERROR", "message": error_msg}, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Ensure sending end marker - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + trace_exception(e) + raise return StreamingResponse( stream_generator(), diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index cadd02cb..cfd7557f 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -17,6 +17,24 @@ from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path +from enum import Enum, auto + + +class ErrorCode(Enum): + """Error codes for MCP errors""" + + InvalidRequest = auto() + InternalError = auto() + + +class McpError(Exception): + """Base exception class for MCP errors""" + + def __init__(self, code: ErrorCode, message: str): + self.code = code + self.message = message + super().__init__(message) + DEFAULT_CONFIG = { "server": { @@ -634,35 +652,82 @@ def test_generate_concurrent() -> None: async with aiohttp.ClientSession() as session: yield session - async def make_request(session, prompt: str): + async def make_request(session, prompt: str, request_id: int): url = get_base_url("generate") data = create_generate_request_data(prompt, stream=False) try: async with session.post(url, json=data) as response: if response.status != 200: - response.raise_for_status() - return await response.json() + error_msg = ( + f"Request {request_id} failed with status {response.status}" + ) + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + result = await response.json() + if "error" in result: + error_msg = ( + f"Request {request_id} returned error: {result['error']}" + ) + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + return result except Exception as e: - return {"error": str(e)} + error_msg = f"Request {request_id} failed: {str(e)}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] async with get_session() as session: - tasks = [make_request(session, prompt) for prompt in prompts] - results = await asyncio.gather(*tasks) + tasks = [ + make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts) + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 收集成功和失败的结果 + success_results = [] + error_messages = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + error_messages.append(f"Request {i+1} failed: {str(result)}") + else: + success_results.append((i + 1, result)) + + # 如果有任何错误,在打印完所有结果后抛出异常 + if error_messages: + # 先打印成功的结果 + for req_id, result in success_results: + if OutputControl.is_verbose(): + print(f"\nRequest {req_id} succeeded:") + print_json_response(result) + + # 打印所有错误信息 + error_summary = "\n".join(error_messages) + raise McpError( + ErrorCode.InternalError, + f"Some concurrent requests failed:\n{error_summary}", + ) + return results if OutputControl.is_verbose(): print("\n=== Testing concurrent generate requests ===") # Run concurrent requests - results = asyncio.run(run_concurrent_requests()) - - # Print results - for i, result in enumerate(results, 1): - print(f"\nRequest {i} result:") - print_json_response(result) + try: + results = asyncio.run(run_concurrent_requests()) + # 如果没有异常,打印所有成功的结果 + for i, result in enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + except McpError: + # 错误信息已经在之前打印过了,这里直接抛出 + raise def get_test_cases() -> Dict[str, Callable]: