diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py
index 23b09f55..edf97993 100644
--- a/lightrag/api/ollama_api.py
+++ b/lightrag/api/ollama_api.py
@@ -205,14 +205,14 @@ class OllamaAPI:
                     async def stream_generator():
                         try:
                             first_chunk_time = None
-                            last_chunk_time = None
+                            last_chunk_time = time.time_ns()
                             total_response = ""
 
                             # Ensure response is an async generator
                             if isinstance(response, str):
                                 # If it's a string, send in two parts
-                                first_chunk_time = time.time_ns()
-                                last_chunk_time = first_chunk_time
+                                first_chunk_time = start_time
+                                last_chunk_time = time.time_ns()
                                 total_response = response
 
                                 data = {
@@ -241,22 +241,50 @@ class OllamaAPI:
                                 }
                                 yield f"{json.dumps(data, ensure_ascii=False)}\n"
                             else:
-                                async for chunk in response:
-                                    if chunk:
-                                        if first_chunk_time is None:
-                                            first_chunk_time = time.time_ns()
+                                try:
+                                    async for chunk in response:
+                                        if chunk:
+                                            if first_chunk_time is None:
+                                                first_chunk_time = time.time_ns()
 
-                                        last_chunk_time = time.time_ns()
+                                            last_chunk_time = time.time_ns()
 
-                                        total_response += chunk
-                                        data = {
-                                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                                            "response": chunk,
-                                            "done": False,
-                                        }
-                                        yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                                            total_response += chunk
+                                            data = {
+                                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                                                "response": chunk,
+                                                "done": False,
+                                            }
+                                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                                except (asyncio.CancelledError, Exception) as e:
+                                    error_msg = str(e)
+                                    if isinstance(e, asyncio.CancelledError):
+                                        error_msg = "Stream was cancelled by server"
+                                    else:
+                                        error_msg = f"Provider error: {error_msg}"
+                                    logging.error(f"Stream error: {error_msg}")
+
+                                    # Send error message to client
+                                    error_data = {
+                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                                        "response": f"\n\nError: {error_msg}",
+                                        "done": False,
+                                    }
+                                    yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
+
+                                    # Send final message to close the stream
+                                    final_data = {
+                                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                                        "done": True,
+                                    }
+                                    yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
+                                    return
 
+                            if first_chunk_time is None:
+                                first_chunk_time = start_time
                             completion_tokens = estimate_tokens(total_response)
                             total_time = last_chunk_time - start_time
                             prompt_eval_time = first_chunk_time - start_time
@@ -381,16 +409,16 @@ class OllamaAPI:
                     )
 
                     async def stream_generator():
-                        first_chunk_time = None
-                        last_chunk_time = None
-                        total_response = ""
-
                         try:
+                            first_chunk_time = None
+                            last_chunk_time = time.time_ns()
+                            total_response = ""
+
                             # Ensure response is an async generator
                             if isinstance(response, str):
                                 # If it's a string, send in two parts
-                                first_chunk_time = time.time_ns()
-                                last_chunk_time = first_chunk_time
+                                first_chunk_time = start_time
+                                last_chunk_time = time.time_ns()
                                 total_response = response
 
                                 data = {
@@ -474,45 +502,29 @@ class OllamaAPI:
                                     yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                                     return
 
-                            if last_chunk_time is not None:
-                                completion_tokens = estimate_tokens(total_response)
-                                total_time = last_chunk_time - start_time
-                                prompt_eval_time = first_chunk_time - start_time
-                                eval_time = last_chunk_time - first_chunk_time
+                            if first_chunk_time is None:
+                                first_chunk_time = start_time
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
 
-                                data = {
-                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                                    "done": True,
-                                    "total_duration": total_time,
-                                    "load_duration": 0,
-                                    "prompt_eval_count": prompt_tokens,
-                                    "prompt_eval_duration": prompt_eval_time,
-                                    "eval_count": completion_tokens,
-                                    "eval_duration": eval_time,
-                                }
-                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                            data = {
+                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
 
                         except Exception as e:
-                            error_msg = f"Error in stream_generator: {str(e)}"
-                            logging.error(error_msg)
-
-                            # Send error message to client
-                            error_data = {
-                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                                "error": {"code": "STREAM_ERROR", "message": error_msg},
-                            }
-                            yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
-
-                            # Ensure sending end marker
-                            final_data = {
-                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                                "done": True,
-                            }
-                            yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
-                            return
+                            trace_exception(e)
+                            raise
 
                     return StreamingResponse(
                         stream_generator(),
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 19f560e7..9e38917a 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -66,6 +66,7 @@ from lightrag.exceptions import (
     RateLimitError,
     APITimeoutError,
 )
+from lightrag.api import __api_version__
 
 import numpy as np
 from typing import Union
@@ -91,11 +92,12 @@ async def ollama_model_if_cache(
     timeout = kwargs.pop("timeout", None)
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
-    headers = (
-        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-        if api_key
-        else {"Content-Type": "application/json"}
-    )
+    headers = {
+        "Content-Type": "application/json",
+        "User-Agent": f"LightRAG/{__api_version__}",
+    }
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
     messages = []
     if system_prompt:
@@ -147,11 +149,12 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
 
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
-    headers = (
-        {"Content-Type": "application/json", "Authorization": api_key}
-        if api_key
-        else {"Content-Type": "application/json"}
-    )
+    headers = {
+        "Content-Type": "application/json",
+        "User-Agent": f"LightRAG/{__api_version__}",
+    }
+    if api_key:
+        headers["Authorization"] = api_key
     kwargs["headers"] = headers
     ollama_client = ollama.Client(**kwargs)
     data = ollama_client.embed(model=embed_model, input=texts)
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 4ba06d2a..535d665c 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -73,16 +73,23 @@ from lightrag.utils import (
     logger,
 )
 from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__
 
 import numpy as np
 from typing import Union
 
 
+class InvalidResponseError(Exception):
+ """Custom exception class for triggering retry mechanism""" + + pass + + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APITimeoutError) + (RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError) ), ) async def openai_complete_if_cache( @@ -99,8 +106,14 @@ async def openai_complete_if_cache( if api_key: os.environ["OPENAI_API_KEY"] = api_key + default_headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json", + } openai_async_client = ( - AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + AsyncOpenAI(default_headers=default_headers) + if base_url is None + else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) kwargs.pop("hashing_kv", None) kwargs.pop("keyword_extraction", None) @@ -112,17 +125,35 @@ async def openai_complete_if_cache( # 添加日志输出 logger.debug("===== Query Input to LLM =====") + logger.debug(f"Model: {model} Base URL: {base_url}") + logger.debug(f"Additional kwargs: {kwargs}") logger.debug(f"Query: {prompt}") logger.debug(f"System prompt: {system_prompt}") - logger.debug("Full context:") - if "response_format" in kwargs: - response = await openai_async_client.beta.chat.completions.parse( - model=model, messages=messages, **kwargs - ) - else: - response = await openai_async_client.chat.completions.create( - model=model, messages=messages, **kwargs - ) + # logger.debug(f"Messages: {messages}") + + try: + if "response_format" in kwargs: + response = await openai_async_client.beta.chat.completions.parse( + model=model, messages=messages, **kwargs + ) + else: + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + except APIConnectionError as e: + logger.error(f"OpenAI API Connection Error: {str(e)}") + raise + except RateLimitError as e: + logger.error(f"OpenAI API Rate Limit Error: {str(e)}") + raise + except APITimeoutError as e: + logger.error(f"OpenAI API Timeout Error: {str(e)}") + raise + except Exception as e: + logger.error(f"OpenAI API Call Failed: {str(e)}") + logger.error(f"Model: {model}") + logger.error(f"Request parameters: {kwargs}") + raise if hasattr(response, "__aiter__"): @@ -140,8 +171,23 @@ async def openai_complete_if_cache( raise return inner() + else: + if ( + not response + or not response.choices + or not hasattr(response.choices[0], "message") + or not hasattr(response.choices[0].message, "content") + ): + logger.error("Invalid response from OpenAI API") + raise InvalidResponseError("Invalid response from OpenAI API") + content = response.choices[0].message.content + + if not content or content.strip() == "": + logger.error("Received empty content from OpenAI API") + raise InvalidResponseError("Received empty content from OpenAI API") + if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) return content @@ -251,8 +297,14 @@ async def openai_embed( if api_key: os.environ["OPENAI_API_KEY"] = api_key + default_headers = { + "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json", + } openai_async_client = ( - AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + AsyncOpenAI(default_headers=default_headers) + if base_url is None + else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) response = await 
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py
index 8cb633ba..80038928 100644
--- a/test_lightrag_ollama_chat.py
+++ b/test_lightrag_ollama_chat.py
@@ -17,14 +17,32 @@ from typing import Dict, Any, Optional, List, Callable
 from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
+from enum import Enum, auto
+
+
+class ErrorCode(Enum):
+    """Error codes for MCP errors"""
+
+    InvalidRequest = auto()
+    InternalError = auto()
+
+
+class McpError(Exception):
+    """Base exception class for MCP errors"""
+
+    def __init__(self, code: ErrorCode, message: str):
+        self.code = code
+        self.message = message
+        super().__init__(message)
+
 
 DEFAULT_CONFIG = {
     "server": {
         "host": "localhost",
         "port": 9621,
         "model": "lightrag:latest",
-        "timeout": 120,
-        "max_retries": 3,
+        "timeout": 300,
+        "max_retries": 1,
         "retry_delay": 1,
     },
     "test_cases": {
@@ -527,14 +545,7 @@ def test_non_stream_generate() -> None:
     response_json = response.json()
 
     # Print response content
-    print_json_response(
-        {
-            "model": response_json["model"],
-            "response": response_json["response"],
-            "done": response_json["done"],
-        },
-        "Response content",
-    )
+    print(json.dumps(response_json, ensure_ascii=False, indent=2))
 
 
 def test_stream_generate() -> None:
@@ -641,35 +652,78 @@ def test_generate_concurrent() -> None:
         async with aiohttp.ClientSession() as session:
             yield session
 
-    async def make_request(session, prompt: str):
+    async def make_request(session, prompt: str, request_id: int):
         url = get_base_url("generate")
         data = create_generate_request_data(prompt, stream=False)
         try:
             async with session.post(url, json=data) as response:
                 if response.status != 200:
-                    response.raise_for_status()
-                return await response.json()
+                    error_msg = (
+                        f"Request {request_id} failed with status {response.status}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                result = await response.json()
+                if "error" in result:
+                    error_msg = (
+                        f"Request {request_id} returned error: {result['error']}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                return result
         except Exception as e:
-            return {"error": str(e)}
+            error_msg = f"Request {request_id} failed: {str(e)}"
+            if OutputControl.is_verbose():
+                print(f"\n{error_msg}")
+            raise McpError(ErrorCode.InternalError, error_msg)
 
     async def run_concurrent_requests():
         prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
 
         async with get_session() as session:
-            tasks = [make_request(session, prompt) for prompt in prompts]
-            results = await asyncio.gather(*tasks)
+            tasks = [
+                make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)
+            ]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            success_results = []
+            error_messages = []
+
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    error_messages.append(f"Request {i+1} failed: {str(result)}")
+                else:
+                    success_results.append((i + 1, result))
+
+            if error_messages:
+                for req_id, result in success_results:
+                    if OutputControl.is_verbose():
+                        print(f"\nRequest {req_id} succeeded:")
+                        print_json_response(result)
+
+                error_summary = "\n".join(error_messages)
+                raise McpError(
+                    ErrorCode.InternalError,
+                    f"Some concurrent requests failed:\n{error_summary}",
+                )
+
             return results
 
     if OutputControl.is_verbose():
         print("\n=== Testing concurrent generate requests ===")
 
     # Run concurrent requests
-    results = asyncio.run(run_concurrent_requests())
-
-    # Print results
-    for i, result in enumerate(results, 1):
-        print(f"\nRequest {i} result:")
-        print_json_response(result)
+    try:
+        results = asyncio.run(run_concurrent_requests())
+        # all success, print out results
+        for i, result in enumerate(results, 1):
+            print(f"\nRequest {i} result:")
+            print_json_response(result)
+    except McpError:
+        # error message already printed
+        raise
 
 
 def get_test_cases() -> Dict[str, Callable]:
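
Note on the retry change above: adding InvalidResponseError to retry_if_exception_type in lightrag/llm/openai.py means an empty or malformed completion now triggers tenacity's retry loop instead of being handed back to the caller. The following minimal, self-contained sketch illustrates that pattern only; flaky_call and its attempt counter are hypothetical stand-ins for the OpenAI request, not code from this PR.

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


class InvalidResponseError(Exception):
    """Raised when the provider returns an empty or malformed response."""


attempts = {"count": 0}


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((InvalidResponseError,)),
)
def flaky_call() -> str:
    # Hypothetical provider call: fail twice with a retryable error, then succeed.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise InvalidResponseError("Received empty content from provider")
    return "non-empty content"


if __name__ == "__main__":
    # Prints "non-empty content" after two retried failures with exponential backoff.
    print(flaky_call())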