From e26c6e564d52431d70d70d8a4740843feacac5f7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 02:43:06 +0800 Subject: [PATCH 1/6] refactor: enhance stream error handling and optimize code structure - Initialize timestamps at start to avoid null checks - Add detailed error handling for streaming response - Handle CancelledError and other exceptions separately - Unify exception handling with trace_exception - Clean up redundant code and simplify logic --- lightrag/api/ollama_api.py | 124 ++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 23b09f55..02b4b573 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,16 +203,16 @@ class OllamaAPI: ) async def stream_generator(): + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = "" + try: - first_chunk_time = None - last_chunk_time = None - total_response = "" # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -241,21 +241,48 @@ class OllamaAPI: } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() - last_chunk_time = time.time_ns() + last_chunk_time = time.time_ns() - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logging.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": f"\n\nError: {error_msg}", + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Send final message to close the stream + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time @@ -381,16 +408,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = None + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -474,45 +500,27 @@ class 
OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - if last_chunk_time is not None: - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" except Exception as e: - error_msg = f"Error in stream_generator: {str(e)}" - logging.error(error_msg) - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": {"code": "STREAM_ERROR", "message": error_msg}, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Ensure sending end marker - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + trace_exception(e) + raise return StreamingResponse( stream_generator(), From 65dc0a6cfd47746d2a77799b51be3eb87ef932bc Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 02:50:27 +0800 Subject: [PATCH 2/6] Fix linting --- lightrag/api/ollama_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 02b4b573..c6f40879 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -208,7 +208,6 @@ class OllamaAPI: total_response = "" try: - # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts @@ -282,7 +281,7 @@ class OllamaAPI: "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + return completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time From 9242f16cc13f2a9895c93d835c0642bc6d21d39d Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 03:17:27 +0800 Subject: [PATCH 3/6] Add error handling and improve logging for concurrent request testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Added McpError and ErrorCode classes • Added detailed error collection logic • Improved error reporting & formatting • Added request ID tracking • Enhanced test results visibility --- test_lightrag_ollama_chat.py | 80 ++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 8cb633ba..d17d50fc 100644 --- 
a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -17,6 +17,19 @@ from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path +from enum import Enum, auto + +class ErrorCode(Enum): + """Error codes for MCP errors""" + InvalidRequest = auto() + InternalError = auto() + +class McpError(Exception): + """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): + self.code = code + self.message = message + super().__init__(message) DEFAULT_CONFIG = { "server": { @@ -641,35 +654,76 @@ def test_generate_concurrent() -> None: async with aiohttp.ClientSession() as session: yield session - async def make_request(session, prompt: str): + async def make_request(session, prompt: str, request_id: int): url = get_base_url("generate") data = create_generate_request_data(prompt, stream=False) try: async with session.post(url, json=data) as response: if response.status != 200: - response.raise_for_status() - return await response.json() + error_msg = f"Request {request_id} failed with status {response.status}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + result = await response.json() + if "error" in result: + error_msg = f"Request {request_id} returned error: {result['error']}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + return result except Exception as e: - return {"error": str(e)} + error_msg = f"Request {request_id} failed: {str(e)}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt) for prompt in prompts] - results = await asyncio.gather(*tasks) + tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 收集成功和失败的结果 + success_results = [] + error_messages = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + error_messages.append(f"Request {i+1} failed: {str(result)}") + else: + success_results.append((i+1, result)) + + # 如果有任何错误,在打印完所有结果后抛出异常 + if error_messages: + # 先打印成功的结果 + for req_id, result in success_results: + if OutputControl.is_verbose(): + print(f"\nRequest {req_id} succeeded:") + print_json_response(result) + + # 打印所有错误信息 + error_summary = "\n".join(error_messages) + raise McpError( + ErrorCode.InternalError, + f"Some concurrent requests failed:\n{error_summary}" + ) + return results if OutputControl.is_verbose(): print("\n=== Testing concurrent generate requests ===") # Run concurrent requests - results = asyncio.run(run_concurrent_requests()) - - # Print results - for i, result in enumerate(results, 1): - print(f"\nRequest {i} result:") - print_json_response(result) + try: + results = asyncio.run(run_concurrent_requests()) + # 如果没有异常,打印所有成功的结果 + for i, result in enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + except McpError as e: + # 错误信息已经在之前打印过了,这里直接抛出 + raise def get_test_cases() -> Dict[str, Callable]: From e49d3665aa04b4748f9715aca496c57bb6ac9f97 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 03:18:05 +0800 Subject: [PATCH 4/6] Fix linting --- test_lightrag_ollama_chat.py | 35 
+++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index d17d50fc..08ecad67 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -19,18 +19,23 @@ from datetime import datetime from pathlib import Path from enum import Enum, auto + class ErrorCode(Enum): """Error codes for MCP errors""" + InvalidRequest = auto() InternalError = auto() + class McpError(Exception): """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): self.code = code self.message = message super().__init__(message) + DEFAULT_CONFIG = { "server": { "host": "localhost", @@ -660,13 +665,17 @@ def test_generate_concurrent() -> None: try: async with session.post(url, json=data) as response: if response.status != 200: - error_msg = f"Request {request_id} failed with status {response.status}" + error_msg = ( + f"Request {request_id} failed with status {response.status}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) result = await response.json() if "error" in result: - error_msg = f"Request {request_id} returned error: {result['error']}" + error_msg = ( + f"Request {request_id} returned error: {result['error']}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) @@ -679,21 +688,23 @@ def test_generate_concurrent() -> None: async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + tasks = [ + make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts) + ] results = await asyncio.gather(*tasks, return_exceptions=True) - + # 收集成功和失败的结果 success_results = [] error_messages = [] - + for i, result in enumerate(results): if isinstance(result, Exception): error_messages.append(f"Request {i+1} failed: {str(result)}") else: - success_results.append((i+1, result)) - + success_results.append((i + 1, result)) + # 如果有任何错误,在打印完所有结果后抛出异常 if error_messages: # 先打印成功的结果 @@ -701,14 +712,14 @@ def test_generate_concurrent() -> None: if OutputControl.is_verbose(): print(f"\nRequest {req_id} succeeded:") print_json_response(result) - + # 打印所有错误信息 error_summary = "\n".join(error_messages) raise McpError( ErrorCode.InternalError, - f"Some concurrent requests failed:\n{error_summary}" + f"Some concurrent requests failed:\n{error_summary}", ) - + return results if OutputControl.is_verbose(): @@ -721,7 +732,7 @@ def test_generate_concurrent() -> None: for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) - except McpError as e: + except McpError: # 错误信息已经在之前打印过了,这里直接抛出 raise From e124ad7f9cfd1c366b4b2c8fb1aaab2ffee1703e Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 04:53:05 +0800 Subject: [PATCH 5/6] Fix timing calculation logic in OllamaAPI stream generators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Initialize first_chunk_time as None • Set timing only when first chunk arrives --- lightrag/api/ollama_api.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index c6f40879..132601c3 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,14 +203,15 @@ class OllamaAPI: ) async 
def stream_generator(): - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts + first_chunk_time = last_chunk_time last_chunk_time = time.time_ns() total_response = response @@ -282,7 +283,8 @@ class OllamaAPI: } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - + if first_chunk_time is None: + first_chunk_time = last_chunk_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -407,14 +409,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts + first_chunk_time = last_chunk_time last_chunk_time = time.time_ns() total_response = response @@ -499,6 +502,8 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return + if first_chunk_time is None: + first_chunk_time = last_chunk_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time From 9103e7f46339c10a0e7d66f477300d26fbf133b9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 10:42:49 +0800 Subject: [PATCH 6/6] fix: improve timing accuracy and variable scoping in OllamaAPI --- lightrag/api/ollama_api.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 132601c3..edf97993 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,15 +203,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = last_chunk_time + first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response @@ -284,7 +284,7 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return if first_chunk_time is None: - first_chunk_time = last_chunk_time + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -409,15 +409,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = last_chunk_time + first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response @@ -503,7 +503,7 @@ class OllamaAPI: return if first_chunk_time is None: - first_chunk_time = last_chunk_time + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time
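
Taken together, patches 1, 5 and 6 converge on one streaming pattern for the Ollama-compatible endpoints: timestamps taken with time.time_ns(), first_chunk_time left as None until the first chunk arrives (falling back to start_time if none does), asyncio.CancelledError reported separately from provider errors, and an error record plus a final done:true record emitted before the stream closes. The sketch below distills that pattern into a standalone async generator for illustration only; MODEL, CREATED_AT, estimate_tokens and stream_ndjson are placeholder names, not the actual code in lightrag/api/ollama_api.py.

    import asyncio
    import json
    import logging
    import time
    from typing import AsyncIterator

    MODEL = "lightrag:latest"            # placeholder model name
    CREATED_AT = "2025-02-06T00:00:00Z"  # placeholder timestamp

    def estimate_tokens(text: str) -> int:
        # Crude stand-in for the real token estimator: ~1 token per 4 characters.
        return max(1, len(text) // 4)

    async def stream_ndjson(chunks: AsyncIterator[str], prompt_tokens: int):
        """Wrap an async chunk source in Ollama-style NDJSON records,
        mirroring the error-handling and timing pattern of the patches."""
        start_time = time.time_ns()
        first_chunk_time = None
        last_chunk_time = start_time
        total_response = ""

        try:
            async for chunk in chunks:
                if not chunk:
                    continue
                if first_chunk_time is None:
                    first_chunk_time = time.time_ns()
                last_chunk_time = time.time_ns()
                total_response += chunk
                data = {"model": MODEL, "created_at": CREATED_AT,
                        "response": chunk, "done": False}
                yield f"{json.dumps(data, ensure_ascii=False)}\n"
        except (asyncio.CancelledError, Exception) as e:
            # CancelledError does not inherit from Exception in Python 3.8+,
            # so it is named explicitly and given its own message.
            msg = ("Stream was cancelled by server"
                   if isinstance(e, asyncio.CancelledError)
                   else f"Provider error: {e}")
            logging.error("Stream error: %s", msg)
            # Surface the error to the client, then close the stream cleanly.
            error_data = {"model": MODEL, "created_at": CREATED_AT,
                          "response": f"\n\nError: {msg}", "done": False}
            yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
            final_data = {"model": MODEL, "created_at": CREATED_AT, "done": True}
            yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
            return

        # No chunk ever arrived: fall back to start_time, as in PATCH 6/6.
        if first_chunk_time is None:
            first_chunk_time = start_time

        final_data = {
            "model": MODEL,
            "created_at": CREATED_AT,
            "done": True,
            "total_duration": last_chunk_time - start_time,
            "load_duration": 0,
            "prompt_eval_count": prompt_tokens,
            "prompt_eval_duration": first_chunk_time - start_time,
            "eval_count": estimate_tokens(total_response),
            "eval_duration": last_chunk_time - first_chunk_time,
        }
        yield f"{json.dumps(final_data, ensure_ascii=False)}\n"

In the API itself the equivalent generator is handed to FastAPI's StreamingResponse (visible at the end of the first hunk), which is why each record is serialized with ensure_ascii=False and terminated by a newline so the client can consume the body as an NDJSON stream.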