From e26c6e564d52431d70d70d8a4740843feacac5f7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 02:43:06 +0800 Subject: [PATCH 01/11] refactor: enhance stream error handling and optimize code structure - Initialize timestamps at start to avoid null checks - Add detailed error handling for streaming response - Handle CancelledError and other exceptions separately - Unify exception handling with trace_exception - Clean up redundant code and simplify logic --- lightrag/api/ollama_api.py | 124 ++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 23b09f55..02b4b573 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,16 +203,16 @@ class OllamaAPI: ) async def stream_generator(): + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = "" + try: - first_chunk_time = None - last_chunk_time = None - total_response = "" # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -241,21 +241,48 @@ class OllamaAPI: } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() - last_chunk_time = time.time_ns() + last_chunk_time = time.time_ns() - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logging.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": f"\n\nError: {error_msg}", + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Send final message to close the stream + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time @@ -381,16 +408,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = None + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -474,45 +500,27 @@ class 
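Taken together, the hunk above turns a mid-stream failure into data the client can parse: the error text is sent as one more "response" record, followed by a record with "done": true so the NDJSON stream is closed cleanly instead of being dropped. Below is a minimal, self-contained sketch of that pattern under stated assumptions: MODEL, CREATED_AT and the _record helper are stand-ins for the self.ollama_server_infos fields and the inline json.dumps calls in the diff.

```python
import asyncio
import json
import logging

# Hypothetical stand-ins for the ollama_server_infos fields used in the diff.
MODEL = "lightrag:latest"
CREATED_AT = "2025-02-06T00:00:00Z"


def _record(**fields) -> str:
    """Serialize one NDJSON record the way the endpoint does."""
    return json.dumps({"model": MODEL, "created_at": CREATED_AT, **fields},
                      ensure_ascii=False) + "\n"


async def stream_ndjson(chunks):
    """Relay chunks; on failure, emit an error record plus a closing done record."""
    try:
        async for chunk in chunks:
            if chunk:
                yield _record(response=chunk, done=False)
    except asyncio.CancelledError:
        # CancelledError is a BaseException, so it needs its own handler.
        error_msg = "Stream was cancelled by server"
        logging.error(f"Stream error: {error_msg}")
        yield _record(response=f"\n\nError: {error_msg}", done=False)
        yield _record(done=True)
        return
    except Exception as e:
        error_msg = f"Provider error: {e}"
        logging.error(f"Stream error: {error_msg}")
        yield _record(response=f"\n\nError: {error_msg}", done=False)
        yield _record(done=True)
        return
    yield _record(done=True)  # normal completion (token/timing stats omitted here)


async def _demo():
    async def chunks():
        yield "Hello, "
        yield "world"
        raise RuntimeError("upstream provider dropped the connection")

    async for line in stream_ndjson(chunks()):
        print(line, end="")


if __name__ == "__main__":
    asyncio.run(_demo())
```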
OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - if last_chunk_time is not None: - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" except Exception as e: - error_msg = f"Error in stream_generator: {str(e)}" - logging.error(error_msg) - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": {"code": "STREAM_ERROR", "message": error_msg}, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Ensure sending end marker - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + trace_exception(e) + raise return StreamingResponse( stream_generator(), From 65dc0a6cfd47746d2a77799b51be3eb87ef932bc Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 02:50:27 +0800 Subject: [PATCH 02/11] Fix linting --- lightrag/api/ollama_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 02b4b573..c6f40879 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -208,7 +208,6 @@ class OllamaAPI: total_response = "" try: - # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts @@ -282,7 +281,7 @@ class OllamaAPI: "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + return completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time From 9242f16cc13f2a9895c93d835c0642bc6d21d39d Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 03:17:27 +0800 Subject: [PATCH 03/11] Add error handling and improve logging for concurrent request testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Added McpError and ErrorCode classes • Added detailed error collection logic • Improved error reporting & formatting • Added request ID tracking • Enhanced test results visibility --- test_lightrag_ollama_chat.py | 80 ++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 8cb633ba..d17d50fc 100644 --- 
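Stepping back to the ollama_api.py change for a moment: the generator above is ultimately handed to StreamingResponse, and after this refactor any exception that escapes it is logged via trace_exception and re-raised rather than rewrapped. A rough sketch of that wiring follows; the route path, media type and headers are illustrative assumptions, since the hunk only shows the StreamingResponse(stream_generator(), ...) call itself.

```python
import json
from typing import AsyncIterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def stream_generator() -> AsyncIterator[str]:
    # Stand-in for the real generator: one content record, then the closing record.
    yield json.dumps({"response": "hello", "done": False}, ensure_ascii=False) + "\n"
    yield json.dumps({"done": True}, ensure_ascii=False) + "\n"


@app.post("/api/generate")
async def generate() -> StreamingResponse:
    # Assumed NDJSON media type and headers; not part of the hunk above.
    return StreamingResponse(
        stream_generator(),
        media_type="application/x-ndjson",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
```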
a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -17,6 +17,19 @@ from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path +from enum import Enum, auto + +class ErrorCode(Enum): + """Error codes for MCP errors""" + InvalidRequest = auto() + InternalError = auto() + +class McpError(Exception): + """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): + self.code = code + self.message = message + super().__init__(message) DEFAULT_CONFIG = { "server": { @@ -641,35 +654,76 @@ def test_generate_concurrent() -> None: async with aiohttp.ClientSession() as session: yield session - async def make_request(session, prompt: str): + async def make_request(session, prompt: str, request_id: int): url = get_base_url("generate") data = create_generate_request_data(prompt, stream=False) try: async with session.post(url, json=data) as response: if response.status != 200: - response.raise_for_status() - return await response.json() + error_msg = f"Request {request_id} failed with status {response.status}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + result = await response.json() + if "error" in result: + error_msg = f"Request {request_id} returned error: {result['error']}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + return result except Exception as e: - return {"error": str(e)} + error_msg = f"Request {request_id} failed: {str(e)}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt) for prompt in prompts] - results = await asyncio.gather(*tasks) + tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 收集成功和失败的结果 + success_results = [] + error_messages = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + error_messages.append(f"Request {i+1} failed: {str(result)}") + else: + success_results.append((i+1, result)) + + # 如果有任何错误,在打印完所有结果后抛出异常 + if error_messages: + # 先打印成功的结果 + for req_id, result in success_results: + if OutputControl.is_verbose(): + print(f"\nRequest {req_id} succeeded:") + print_json_response(result) + + # 打印所有错误信息 + error_summary = "\n".join(error_messages) + raise McpError( + ErrorCode.InternalError, + f"Some concurrent requests failed:\n{error_summary}" + ) + return results if OutputControl.is_verbose(): print("\n=== Testing concurrent generate requests ===") # Run concurrent requests - results = asyncio.run(run_concurrent_requests()) - - # Print results - for i, result in enumerate(results, 1): - print(f"\nRequest {i} result:") - print_json_response(result) + try: + results = asyncio.run(run_concurrent_requests()) + # 如果没有异常,打印所有成功的结果 + for i, result in enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + except McpError as e: + # 错误信息已经在之前打印过了,这里直接抛出 + raise def get_test_cases() -> Dict[str, Callable]: From e49d3665aa04b4748f9715aca496c57bb6ac9f97 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 03:18:05 +0800 Subject: [PATCH 04/11] Fix linting --- test_lightrag_ollama_chat.py | 35 
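The heart of this test change is asyncio.gather(..., return_exceptions=True) followed by splitting the results into successes and failures, printing the successes first, and folding every failure into one McpError so the test still fails loudly. A trimmed, runnable sketch of that flow, with fake_request standing in for the real aiohttp call:

```python
import asyncio
from enum import Enum, auto


class ErrorCode(Enum):
    """Error codes for MCP errors"""
    InvalidRequest = auto()
    InternalError = auto()


class McpError(Exception):
    """Base exception class for MCP errors"""
    def __init__(self, code: ErrorCode, message: str):
        self.code = code
        self.message = message
        super().__init__(message)


async def fake_request(request_id: int) -> dict:
    """Stand-in for make_request(); request 3 fails to show the aggregation path."""
    if request_id == 3:
        raise McpError(ErrorCode.InternalError, "simulated provider error")
    return {"response": f"answer {request_id}", "done": True}


async def run_concurrent_requests(n: int = 5) -> list:
    tasks = [fake_request(i + 1) for i in range(n)]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    successes = [(i + 1, r) for i, r in enumerate(results) if not isinstance(r, Exception)]
    errors = [f"Request {i + 1} failed: {r}" for i, r in enumerate(results) if isinstance(r, Exception)]

    for req_id, result in successes:  # report successes before raising
        print(f"Request {req_id} succeeded: {result}")
    if errors:
        raise McpError(ErrorCode.InternalError,
                       "Some concurrent requests failed:\n" + "\n".join(errors))
    return results


if __name__ == "__main__":
    try:
        asyncio.run(run_concurrent_requests())
    except McpError as e:
        print(e)
```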
+++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index d17d50fc..08ecad67 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -19,18 +19,23 @@ from datetime import datetime from pathlib import Path from enum import Enum, auto + class ErrorCode(Enum): """Error codes for MCP errors""" + InvalidRequest = auto() InternalError = auto() + class McpError(Exception): """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): self.code = code self.message = message super().__init__(message) + DEFAULT_CONFIG = { "server": { "host": "localhost", @@ -660,13 +665,17 @@ def test_generate_concurrent() -> None: try: async with session.post(url, json=data) as response: if response.status != 200: - error_msg = f"Request {request_id} failed with status {response.status}" + error_msg = ( + f"Request {request_id} failed with status {response.status}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) result = await response.json() if "error" in result: - error_msg = f"Request {request_id} returned error: {result['error']}" + error_msg = ( + f"Request {request_id} returned error: {result['error']}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) @@ -679,21 +688,23 @@ def test_generate_concurrent() -> None: async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + tasks = [ + make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts) + ] results = await asyncio.gather(*tasks, return_exceptions=True) - + # 收集成功和失败的结果 success_results = [] error_messages = [] - + for i, result in enumerate(results): if isinstance(result, Exception): error_messages.append(f"Request {i+1} failed: {str(result)}") else: - success_results.append((i+1, result)) - + success_results.append((i + 1, result)) + # 如果有任何错误,在打印完所有结果后抛出异常 if error_messages: # 先打印成功的结果 @@ -701,14 +712,14 @@ def test_generate_concurrent() -> None: if OutputControl.is_verbose(): print(f"\nRequest {req_id} succeeded:") print_json_response(result) - + # 打印所有错误信息 error_summary = "\n".join(error_messages) raise McpError( ErrorCode.InternalError, - f"Some concurrent requests failed:\n{error_summary}" + f"Some concurrent requests failed:\n{error_summary}", ) - + return results if OutputControl.is_verbose(): @@ -721,7 +732,7 @@ def test_generate_concurrent() -> None: for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) - except McpError as e: + except McpError: # 错误信息已经在之前打印过了,这里直接抛出 raise From e124ad7f9cfd1c366b4b2c8fb1aaab2ffee1703e Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 04:53:05 +0800 Subject: [PATCH 05/11] Fix timing calculation logic in OllamaAPI stream generators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Initialize first_chunk_time as None • Set timing only when first chunk arrives --- lightrag/api/ollama_api.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index c6f40879..132601c3 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,14 +203,15 @@ class OllamaAPI: ) async 
def stream_generator(): - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts + first_chunk_time = last_chunk_time last_chunk_time = time.time_ns() total_response = response @@ -282,7 +283,8 @@ class OllamaAPI: } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - + if first_chunk_time is None: + first_chunk_time = last_chunk_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -407,14 +409,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts + first_chunk_time = last_chunk_time last_chunk_time = time.time_ns() total_response = response @@ -499,6 +502,8 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return + if first_chunk_time is None: + first_chunk_time = last_chunk_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time From 9103e7f46339c10a0e7d66f477300d26fbf133b9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 10:42:49 +0800 Subject: [PATCH 06/11] fix: improve timing accuracy and variable scoping in OllamaAPI --- lightrag/api/ollama_api.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 132601c3..edf97993 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,15 +203,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = last_chunk_time + first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response @@ -284,7 +284,7 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return if first_chunk_time is None: - first_chunk_time = last_chunk_time + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -409,15 +409,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = last_chunk_time + first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response @@ -503,7 +503,7 @@ class OllamaAPI: return if first_chunk_time is None: - first_chunk_time = last_chunk_time + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time From 
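The two timing commits (05 and 06) converge on a single rule: keep three time.time_ns() stamps (start_time, first_chunk_time, last_chunk_time), let first_chunk_time stay None until a chunk actually arrives, and fall back to start_time when nothing was streamed. The closing record's duration fields then follow directly; a condensed sketch with names matching the diff:

```python
import time
from typing import Optional


def timing_fields(start_time: int, first_chunk_time: Optional[int], last_chunk_time: int) -> dict:
    """Duration fields of the closing record, computed from time.time_ns() stamps."""
    if first_chunk_time is None:       # nothing streamed: fall back to start_time
        first_chunk_time = start_time
    return {
        "total_duration": last_chunk_time - start_time,
        "prompt_eval_duration": first_chunk_time - start_time,
        "eval_duration": last_chunk_time - first_chunk_time,
    }


# Typical use around the chunk loop:
start_time = time.time_ns()
first_chunk_time: Optional[int] = None
last_chunk_time = start_time
for _ in range(3):                     # placeholder for the async chunk loop
    if first_chunk_time is None:
        first_chunk_time = time.time_ns()
    last_chunk_time = time.time_ns()
print(timing_fields(start_time, first_chunk_time, last_chunk_time))
```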
506e39e14ed48e9a014937eb78fbdd21a1ecbb60 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 19:42:57 +0800 Subject: [PATCH 07/11] Enhance OpenAI API error handling and logging for better reliability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add InvalidResponseError custom exception • Improve error logging for API failures • Add empty response content validation • Add more detailed debug logging info • Add retry for invalid response cases --- lightrag/llm/openai.py | 56 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 4ba06d2a..3f939d62 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -77,12 +77,15 @@ from lightrag.types import GPTKeywordExtractionFormat import numpy as np from typing import Union +class InvalidResponseError(Exception): + """Custom exception class for triggering retry mechanism""" + pass @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APITimeoutError) + (RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError) ), ) async def openai_complete_if_cache( @@ -112,17 +115,35 @@ async def openai_complete_if_cache( # 添加日志输出 logger.debug("===== Query Input to LLM =====") + logger.debug(f"Model: {model} Base URL: {base_url}") + logger.debug(f"Additional kwargs: {kwargs}") logger.debug(f"Query: {prompt}") logger.debug(f"System prompt: {system_prompt}") - logger.debug("Full context:") - if "response_format" in kwargs: - response = await openai_async_client.beta.chat.completions.parse( - model=model, messages=messages, **kwargs - ) - else: - response = await openai_async_client.chat.completions.create( - model=model, messages=messages, **kwargs - ) + # logger.debug(f"Messages: {messages}") + + try: + if "response_format" in kwargs: + response = await openai_async_client.beta.chat.completions.parse( + model=model, messages=messages, **kwargs + ) + else: + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + except APIConnectionError as e: + logger.error(f"OpenAI API Connection Error: {str(e)}") + raise + except RateLimitError as e: + logger.error(f"OpenAI API Rate Limit Error: {str(e)}") + raise + except APITimeoutError as e: + logger.error(f"OpenAI API Timeout Error: {str(e)}") + raise + except Exception as e: + logger.error(f"OpenAI API Call Failed: {str(e)}") + logger.error(f"Model: {model}") + logger.error(f"Request parameters: {kwargs}") + raise if hasattr(response, "__aiter__"): @@ -140,8 +161,23 @@ async def openai_complete_if_cache( raise return inner() + else: + if ( + not response + or not response.choices + or not hasattr(response.choices[0], "message") + or not hasattr(response.choices[0].message, "content") + ): + logger.error("Invalid response from OpenAI API") + raise InvalidResponseError("Invalid response from OpenAI API") + content = response.choices[0].message.content + + if not content or content.strip() == "": + logger.error("Received empty content from OpenAI API") + raise InvalidResponseError("Received empty content from OpenAI API") + if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) return content From 6ca7487cac5cf0c56a98bd737fcb730add56c0a9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 19:56:18 +0800 Subject: [PATCH 08/11] Update timeout and max_retries 
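The retry change above hinges on treating an empty or structurally invalid completion the same way as a transport error: it is raised as InvalidResponseError, which is listed in the tenacity retry predicate so it gets the same exponential backoff. A reduced, self-contained sketch of that shape follows; fake_openai_call and the builtin TimeoutError/ConnectionError are stand-ins for the real chat.completions call and the OpenAI SDK's RateLimitError/APIConnectionError/APITimeoutError.

```python
from types import SimpleNamespace

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class InvalidResponseError(Exception):
    """Custom exception class for triggering retry mechanism"""


_attempts = {"n": 0}


def fake_openai_call(prompt: str) -> SimpleNamespace:
    """Placeholder for the real API call; returns empty content on the first attempt."""
    _attempts["n"] += 1
    content = "" if _attempts["n"] == 1 else f"echo: {prompt}"
    return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=content))])


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # The real decorator lists the OpenAI SDK transport errors alongside InvalidResponseError.
    retry=retry_if_exception_type((TimeoutError, ConnectionError, InvalidResponseError)),
)
def complete(prompt: str) -> str:
    response = fake_openai_call(prompt)
    # Validate structure before touching .content, mirroring the checks in the diff.
    if not response or not getattr(response, "choices", None):
        raise InvalidResponseError("Invalid response from OpenAI API")
    content = response.choices[0].message.content
    if not content or content.strip() == "":
        raise InvalidResponseError("Received empty content from OpenAI API")
    return content


print(complete("ping"))  # first attempt raises InvalidResponseError, the retry succeeds
```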
for unit tests --- test_lightrag_ollama_chat.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 8cb633ba..cadd02cb 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -23,8 +23,8 @@ DEFAULT_CONFIG = { "host": "localhost", "port": 9621, "model": "lightrag:latest", - "timeout": 120, - "max_retries": 3, + "timeout": 300, + "max_retries": 1, "retry_delay": 1, }, "test_cases": { @@ -527,14 +527,7 @@ def test_non_stream_generate() -> None: response_json = response.json() # Print response content - print_json_response( - { - "model": response_json["model"], - "response": response_json["response"], - "done": response_json["done"], - }, - "Response content", - ) + print(json.dumps(response_json, ensure_ascii=False, indent=2)) def test_stream_generate() -> None: From 2760433634275ca1dc1a28802fe58612e8110760 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 22:55:22 +0800 Subject: [PATCH 09/11] Add LightRAG version to User-Agent header for better request tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add User-Agent header with version info • Update header creation in Ollama client • Update header creation in OpenAI client • Ensure consistent header format • Include Mozilla UA string for OpenAI --- lightrag/llm/ollama.py | 23 +++++++++++++---------- lightrag/llm/openai.py | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index 19f560e7..c65954f1 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -66,6 +66,7 @@ from lightrag.exceptions import ( RateLimitError, APITimeoutError, ) +from lightrag.api import __api_version__ import numpy as np from typing import Union @@ -91,11 +92,12 @@ async def ollama_model_if_cache( timeout = kwargs.pop("timeout", None) kwargs.pop("hashing_kv", None) api_key = kwargs.pop("api_key", None) - headers = ( - {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - if api_key - else {"Content-Type": "application/json"} - ) + headers = { + "Content-Type": "application/json", + "User-Agent": f"LightRAG/{__api_version__}" + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) messages = [] if system_prompt: @@ -147,11 +149,12 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: api_key = kwargs.pop("api_key", None) - headers = ( - {"Content-Type": "application/json", "Authorization": api_key} - if api_key - else {"Content-Type": "application/json"} - ) + headers = { + "Content-Type": "application/json", + "User-Agent": f"LightRAG/{__api_version__}" + } + if api_key: + headers["Authorization"] = api_key kwargs["headers"] = headers ollama_client = ollama.Client(**kwargs) data = ollama_client.embed(model=embed_model, input=texts) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 3f939d62..ca451bcf 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -73,6 +73,7 @@ from lightrag.utils import ( logger, ) from lightrag.types import GPTKeywordExtractionFormat +from lightrag.api import __api_version__ import numpy as np from typing import Union @@ -102,8 +103,13 @@ async def openai_complete_if_cache( if api_key: os.environ["OPENAI_API_KEY"] = api_key + 
default_headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json" + } openai_async_client = ( - AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + AsyncOpenAI(default_headers=default_headers) if base_url is None + else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) kwargs.pop("hashing_kv", None) kwargs.pop("keyword_extraction", None) @@ -287,8 +293,13 @@ async def openai_embed( if api_key: os.environ["OPENAI_API_KEY"] = api_key + default_headers = { + "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json" + } openai_async_client = ( - AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + AsyncOpenAI(default_headers=default_headers) if base_url is None + else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" From c838229f0504a796ba0c348256ab3949f28de16d Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 23:10:15 +0800 Subject: [PATCH 10/11] Remove Chinese comments and added English comments for clarity --- test_lightrag_ollama_chat.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index cfd7557f..80038928 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -688,7 +688,6 @@ def test_generate_concurrent() -> None: ] results = await asyncio.gather(*tasks, return_exceptions=True) - # 收集成功和失败的结果 success_results = [] error_messages = [] @@ -698,15 +697,12 @@ def test_generate_concurrent() -> None: else: success_results.append((i + 1, result)) - # 如果有任何错误,在打印完所有结果后抛出异常 if error_messages: - # 先打印成功的结果 for req_id, result in success_results: if OutputControl.is_verbose(): print(f"\nRequest {req_id} succeeded:") print_json_response(result) - # 打印所有错误信息 error_summary = "\n".join(error_messages) raise McpError( ErrorCode.InternalError, @@ -721,12 +717,12 @@ def test_generate_concurrent() -> None: # Run concurrent requests try: results = asyncio.run(run_concurrent_requests()) - # 如果没有异常,打印所有成功的结果 + # all success, print out results for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) except McpError: - # 错误信息已经在之前打印过了,这里直接抛出 + # error message already printed raise From 8fdba5d4db5f799013ffdee9dc7a716023751640 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 23:12:35 +0800 Subject: [PATCH 11/11] Fix linting --- lightrag/llm/ollama.py | 4 ++-- lightrag/llm/openai.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index c65954f1..9e38917a 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -94,7 +94,7 @@ async def ollama_model_if_cache( api_key = kwargs.pop("api_key", None) headers = { "Content-Type": "application/json", - "User-Agent": f"LightRAG/{__api_version__}" + "User-Agent": f"LightRAG/{__api_version__}", } if api_key: headers["Authorization"] = f"Bearer {api_key}" @@ -151,7 +151,7 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: api_key = kwargs.pop("api_key", None) headers = { "Content-Type": "application/json", - "User-Agent": f"LightRAG/{__api_version__}" + "User-Agent": f"LightRAG/{__api_version__}", } if api_key: headers["Authorization"] = api_key diff 
--git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index ca451bcf..535d665c 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -78,10 +78,13 @@ from lightrag.api import __api_version__ import numpy as np from typing import Union + class InvalidResponseError(Exception): """Custom exception class for triggering retry mechanism""" + pass + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -105,10 +108,11 @@ async def openai_complete_if_cache( default_headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", - "Content-Type": "application/json" + "Content-Type": "application/json", } openai_async_client = ( - AsyncOpenAI(default_headers=default_headers) if base_url is None + AsyncOpenAI(default_headers=default_headers) + if base_url is None else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) kwargs.pop("hashing_kv", None) @@ -295,10 +299,11 @@ async def openai_embed( default_headers = { "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", - "Content-Type": "application/json" + "Content-Type": "application/json", } openai_async_client = ( - AsyncOpenAI(default_headers=default_headers) if base_url is None + AsyncOpenAI(default_headers=default_headers) + if base_url is None else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) response = await openai_async_client.embeddings.create(
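One detail worth flagging in the User-Agent hunks: in openai_complete_if_cache the header value is a plain string literal, not an f-string, so the text "{__api_version__}" would be sent verbatim, whereas openai_embed uses the f-string form. The sketch below shows the interpolated version of the client construction; __api_version__ here is a stand-in for lightrag.api.__api_version__, and OPENAI_API_KEY is assumed to be present in the environment, as in the surrounding code. The Ollama client follows the same idea with a plain headers dict plus an optional Authorization entry.

```python
from typing import Optional

from openai import AsyncOpenAI

__api_version__ = "0.0.0"  # stand-in for lightrag.api.__api_version__


def make_openai_client(base_url: Optional[str] = None) -> AsyncOpenAI:
    """Build an AsyncOpenAI client whose requests carry a version-tagged User-Agent."""
    default_headers = {
        # f-string, so the actual version number is interpolated into the header value
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }
    # OPENAI_API_KEY is expected to be set in the environment by the caller.
    if base_url is None:
        return AsyncOpenAI(default_headers=default_headers)
    return AsyncOpenAI(base_url=base_url, default_headers=default_headers)
```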