From 19ee3d109c10d45e9501d7f3232db527e02a8c4b Mon Sep 17 00:00:00 2001 From: ultrageopro Date: Thu, 6 Feb 2025 22:56:17 +0300 Subject: [PATCH 01/12] =?UTF-8?q?feat:=20trimming=20the=20model=E2=80=99s?= =?UTF-8?q?=20reasoning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 ++++++ lightrag/llm/ollama.py | 18 ++++++++++++++++-- lightrag/utils.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d9315458..456d9a72 100644 --- a/README.md +++ b/README.md @@ -338,6 +338,12 @@ rag = LightRAG( There fully functional example `examples/lightrag_ollama_demo.py` that utilizes `gemma2:2b` model, runs only 4 requests in parallel and set context size to 32k. +#### Using "Thinking" Models (e.g., DeepSeek) + +To return only the model's response, you can pass `reasoning_tag` in `llm_model_kwargs`. + +For example, for DeepSeek models, `reasoning_tag` should be set to `think`. + #### Low RAM GPUs In order to run this experiment on low RAM GPU you should select small model and tune context window (increasing context increase memory consumption). For example, running this ollama example on repurposed mining GPU with 6Gb of RAM required to set context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations on `book.txt`. diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index 19f560e7..3541bd67 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -66,6 +66,7 @@ from lightrag.exceptions import ( RateLimitError, APITimeoutError, ) +from lightrag.utils import extract_reasoning import numpy as np from typing import Union @@ -85,6 +86,7 @@ async def ollama_model_if_cache( **kwargs, ) -> Union[str, AsyncIterator[str]]: stream = True if kwargs.get("stream") else False + reasoning_tag = kwargs.pop("reasoning_tag", None) kwargs.pop("max_tokens", None) # kwargs.pop("response_format", None) # allow json host = kwargs.pop("host", None) @@ -105,7 +107,7 @@ async def ollama_model_if_cache( response = await ollama_client.chat(model=model, messages=messages, **kwargs) if stream: - """cannot cache stream response""" + """cannot cache stream response and process reasoning""" async def inner(): async for chunk in response: @@ -113,7 +115,19 @@ async def ollama_model_if_cache( return inner() else: - return response["message"]["content"] + model_response = response["message"]["content"] + + """ + If the model also wraps its thoughts in a specific tag, + this information is not needed for the final + response and can simply be trimmed. + """ + + return ( + model_response + if reasoning_tag is None + else extract_reasoning(model_response, reasoning_tag).response_content + ) async def ollama_model_complete( diff --git a/lightrag/utils.py b/lightrag/utils.py index 3a69513b..ed0b6c06 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -11,6 +11,7 @@ from functools import wraps from hashlib import md5 from typing import Any, Union, List, Optional import xml.etree.ElementTree as ET +import bs4 import numpy as np import tiktoken @@ -64,6 +65,13 @@ class EmbeddingFunc: return await self.func(*args, **kwargs) +@dataclass +class ReasoningResponse: + reasoning_content: str + response_content: str + tag: str + + def locate_json_string_body_from_string(content: str) -> Union[str, None]: """Locate the JSON string body from a string""" try: @@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> ) return "\n".join(formatted_turns) + + +def extract_reasoning(response: str, tag: str) -> ReasoningResponse: + """Extract the reasoning section and the following section from the LLM response. + + Args: + response: LLM response + tag: Tag to extract + Returns: + ReasoningResponse: Reasoning section and following section + + """ + soup = bs4.BeautifulSoup(response, "html.parser") + + reasoning_section = soup.find(tag) + if reasoning_section is None: + return ReasoningResponse(None, response, tag) + reasoning_content = reasoning_section.get_text().strip() + + after_reasoning_section = reasoning_section.next_sibling + if after_reasoning_section is None: + return ReasoningResponse(reasoning_content, "", tag) + after_reasoning_content = after_reasoning_section.get_text().strip() + + return ReasoningResponse(reasoning_content, after_reasoning_content, tag) From f72e4e68302c2e4430d3195575368906e54b0ed2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 19:42:57 +0800 Subject: [PATCH 02/12] Enhance OpenAI API error handling and logging for better reliability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add InvalidResponseError custom exception • Improve error logging for API failures • Add empty response content validation • Add more detailed debug logging info • Add retry for invalid response cases --- lightrag/llm/openai.py | 56 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 4ba06d2a..3f939d62 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -77,12 +77,15 @@ from lightrag.types import GPTKeywordExtractionFormat import numpy as np from typing import Union +class InvalidResponseError(Exception): + """Custom exception class for triggering retry mechanism""" + pass @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APITimeoutError) + (RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError) ), ) async def openai_complete_if_cache( @@ -112,17 +115,35 @@ async def openai_complete_if_cache( # 添加日志输出 logger.debug("===== Query Input to LLM =====") + logger.debug(f"Model: {model} Base URL: {base_url}") + logger.debug(f"Additional kwargs: {kwargs}") logger.debug(f"Query: {prompt}") logger.debug(f"System prompt: {system_prompt}") - logger.debug("Full context:") - if "response_format" in kwargs: - response = await openai_async_client.beta.chat.completions.parse( - model=model, messages=messages, **kwargs - ) - else: - response = await openai_async_client.chat.completions.create( - model=model, messages=messages, **kwargs - ) + # logger.debug(f"Messages: {messages}") + + try: + if "response_format" in kwargs: + response = await openai_async_client.beta.chat.completions.parse( + model=model, messages=messages, **kwargs + ) + else: + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + except APIConnectionError as e: + logger.error(f"OpenAI API Connection Error: {str(e)}") + raise + except RateLimitError as e: + logger.error(f"OpenAI API Rate Limit Error: {str(e)}") + raise + except APITimeoutError as e: + logger.error(f"OpenAI API Timeout Error: {str(e)}") + raise + except Exception as e: + logger.error(f"OpenAI API Call Failed: {str(e)}") + logger.error(f"Model: {model}") + logger.error(f"Request parameters: {kwargs}") + raise if hasattr(response, "__aiter__"): @@ -140,8 +161,23 @@ async def openai_complete_if_cache( raise return inner() + else: + if ( + not response + or not response.choices + or not hasattr(response.choices[0], "message") + or not hasattr(response.choices[0].message, "content") + ): + logger.error("Invalid response from OpenAI API") + raise InvalidResponseError("Invalid response from OpenAI API") + content = response.choices[0].message.content + + if not content or content.strip() == "": + logger.error("Received empty content from OpenAI API") + raise InvalidResponseError("Received empty content from OpenAI API") + if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) return content From cec8da7f913c7e48e54804ecb36d6e0120141a26 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 19:56:18 +0800 Subject: [PATCH 03/12] Update timeout and max_retries for unit tests --- test_lightrag_ollama_chat.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index 8cb633ba..cadd02cb 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -23,8 +23,8 @@ DEFAULT_CONFIG = { "host": "localhost", "port": 9621, "model": "lightrag:latest", - "timeout": 120, - "max_retries": 3, + "timeout": 300, + "max_retries": 1, "retry_delay": 1, }, "test_cases": { @@ -527,14 +527,7 @@ def test_non_stream_generate() -> None: response_json = response.json() # Print response content - print_json_response( - { - "model": response_json["model"], - "response": response_json["response"], - "done": response_json["done"], - }, - "Response content", - ) + print(json.dumps(response_json, ensure_ascii=False, indent=2)) def test_stream_generate() -> None: From b90f3f14be2d7f7c630c70acc7b05933b555f48e Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 22:55:22 +0800 Subject: [PATCH 04/12] Add LightRAG version to User-Agent header for better request tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add User-Agent header with version info • Update header creation in Ollama client • Update header creation in OpenAI client • Ensure consistent header format • Include Mozilla UA string for OpenAI --- lightrag/llm/ollama.py | 23 +++++++++++++---------- lightrag/llm/openai.py | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index 3541bd67..296e263e 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -66,6 +66,7 @@ from lightrag.exceptions import ( RateLimitError, APITimeoutError, ) +from lightrag.api import __api_version__ from lightrag.utils import extract_reasoning import numpy as np from typing import Union @@ -93,11 +94,12 @@ async def ollama_model_if_cache( timeout = kwargs.pop("timeout", None) kwargs.pop("hashing_kv", None) api_key = kwargs.pop("api_key", None) - headers = ( - {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - if api_key - else {"Content-Type": "application/json"} - ) + headers = { + "Content-Type": "application/json", + "User-Agent": f"LightRAG/{__api_version__}", + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) messages = [] if system_prompt: @@ -161,11 +163,12 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: api_key = kwargs.pop("api_key", None) - headers = ( - {"Content-Type": "application/json", "Authorization": api_key} - if api_key - else {"Content-Type": "application/json"} - ) + headers = { + "Content-Type": "application/json", + "User-Agent": f"LightRAG/{__api_version__}", + } + if api_key: + headers["Authorization"] = api_key kwargs["headers"] = headers ollama_client = ollama.Client(**kwargs) data = ollama_client.embed(model=embed_model, input=texts) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 3f939d62..ca451bcf 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -73,6 +73,7 @@ from lightrag.utils import ( logger, ) from lightrag.types import GPTKeywordExtractionFormat +from lightrag.api import __api_version__ import numpy as np from typing import Union @@ -102,8 +103,13 @@ async def openai_complete_if_cache( if api_key: os.environ["OPENAI_API_KEY"] = api_key + default_headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json" + } openai_async_client = ( - AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + AsyncOpenAI(default_headers=default_headers) if base_url is None + else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) kwargs.pop("hashing_kv", None) kwargs.pop("keyword_extraction", None) @@ -287,8 +293,13 @@ async def openai_embed( if api_key: os.environ["OPENAI_API_KEY"] = api_key + default_headers = { + "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json" + } openai_async_client = ( - AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + AsyncOpenAI(default_headers=default_headers) if base_url is None + else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" From 1508dcb40336850b35cd609006efc736ed377188 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 02:43:06 +0800 Subject: [PATCH 05/12] refactor: enhance stream error handling and optimize code structure - Initialize timestamps at start to avoid null checks - Add detailed error handling for streaming response - Handle CancelledError and other exceptions separately - Unify exception handling with trace_exception - Clean up redundant code and simplify logic --- lightrag/api/ollama_api.py | 124 ++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 23b09f55..02b4b573 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,16 +203,16 @@ class OllamaAPI: ) async def stream_generator(): + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = "" + try: - first_chunk_time = None - last_chunk_time = None - total_response = "" # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -241,21 +241,48 @@ class OllamaAPI: } yield f"{json.dumps(data, ensure_ascii=False)}\n" else: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() - last_chunk_time = time.time_ns() + last_chunk_time = time.time_ns() - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + total_response += chunk + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": chunk, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" + + logging.error(f"Stream error: {error_msg}") + + # Send error message to client + error_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "response": f"\n\nError: {error_msg}", + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + + # Send final message to close the stream + final_data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time @@ -381,16 +408,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = None + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + last_chunk_time = time.time_ns() total_response = response data = { @@ -474,45 +500,27 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - if last_chunk_time is not None: - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + data = { + "model": self.ollama_server_infos.LIGHTRAG_MODEL, + "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" except Exception as e: - error_msg = f"Error in stream_generator: {str(e)}" - logging.error(error_msg) - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "error": {"code": "STREAM_ERROR", "message": error_msg}, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Ensure sending end marker - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + trace_exception(e) + raise return StreamingResponse( stream_generator(), From 4bdb8e30de5cb656a2958e2762cc1746ec0077ae Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 02:50:27 +0800 Subject: [PATCH 06/12] Fix linting --- lightrag/api/ollama_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 02b4b573..c6f40879 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -208,7 +208,6 @@ class OllamaAPI: total_response = "" try: - # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts @@ -282,7 +281,7 @@ class OllamaAPI: "done": True, } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + return completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time From e4476ef0afeb6253a9d2794464c48ead858bfef1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 03:17:27 +0800 Subject: [PATCH 07/12] Add error handling and improve logging for concurrent request testing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Added McpError and ErrorCode classes • Added detailed error collection logic • Improved error reporting & formatting • Added request ID tracking • Enhanced test results visibility --- test_lightrag_ollama_chat.py | 80 ++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index cadd02cb..c07fe792 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -17,6 +17,19 @@ from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path +from enum import Enum, auto + +class ErrorCode(Enum): + """Error codes for MCP errors""" + InvalidRequest = auto() + InternalError = auto() + +class McpError(Exception): + """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): + self.code = code + self.message = message + super().__init__(message) DEFAULT_CONFIG = { "server": { @@ -634,35 +647,76 @@ def test_generate_concurrent() -> None: async with aiohttp.ClientSession() as session: yield session - async def make_request(session, prompt: str): + async def make_request(session, prompt: str, request_id: int): url = get_base_url("generate") data = create_generate_request_data(prompt, stream=False) try: async with session.post(url, json=data) as response: if response.status != 200: - response.raise_for_status() - return await response.json() + error_msg = f"Request {request_id} failed with status {response.status}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + result = await response.json() + if "error" in result: + error_msg = f"Request {request_id} returned error: {result['error']}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) + return result except Exception as e: - return {"error": str(e)} + error_msg = f"Request {request_id} failed: {str(e)}" + if OutputControl.is_verbose(): + print(f"\n{error_msg}") + raise McpError(ErrorCode.InternalError, error_msg) async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt) for prompt in prompts] - results = await asyncio.gather(*tasks) + tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 收集成功和失败的结果 + success_results = [] + error_messages = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + error_messages.append(f"Request {i+1} failed: {str(result)}") + else: + success_results.append((i+1, result)) + + # 如果有任何错误,在打印完所有结果后抛出异常 + if error_messages: + # 先打印成功的结果 + for req_id, result in success_results: + if OutputControl.is_verbose(): + print(f"\nRequest {req_id} succeeded:") + print_json_response(result) + + # 打印所有错误信息 + error_summary = "\n".join(error_messages) + raise McpError( + ErrorCode.InternalError, + f"Some concurrent requests failed:\n{error_summary}" + ) + return results if OutputControl.is_verbose(): print("\n=== Testing concurrent generate requests ===") # Run concurrent requests - results = asyncio.run(run_concurrent_requests()) - - # Print results - for i, result in enumerate(results, 1): - print(f"\nRequest {i} result:") - print_json_response(result) + try: + results = asyncio.run(run_concurrent_requests()) + # 如果没有异常,打印所有成功的结果 + for i, result in enumerate(results, 1): + print(f"\nRequest {i} result:") + print_json_response(result) + except McpError as e: + # 错误信息已经在之前打印过了,这里直接抛出 + raise def get_test_cases() -> Dict[str, Callable]: From d297a871905611948742afb082a2c422ae965084 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 03:18:05 +0800 Subject: [PATCH 08/12] Fix linting --- test_lightrag_ollama_chat.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index c07fe792..cfd7557f 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -19,18 +19,23 @@ from datetime import datetime from pathlib import Path from enum import Enum, auto + class ErrorCode(Enum): """Error codes for MCP errors""" + InvalidRequest = auto() InternalError = auto() + class McpError(Exception): """Base exception class for MCP errors""" + def __init__(self, code: ErrorCode, message: str): self.code = code self.message = message super().__init__(message) + DEFAULT_CONFIG = { "server": { "host": "localhost", @@ -653,13 +658,17 @@ def test_generate_concurrent() -> None: try: async with session.post(url, json=data) as response: if response.status != 200: - error_msg = f"Request {request_id} failed with status {response.status}" + error_msg = ( + f"Request {request_id} failed with status {response.status}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) result = await response.json() if "error" in result: - error_msg = f"Request {request_id} returned error: {result['error']}" + error_msg = ( + f"Request {request_id} returned error: {result['error']}" + ) if OutputControl.is_verbose(): print(f"\n{error_msg}") raise McpError(ErrorCode.InternalError, error_msg) @@ -672,21 +681,23 @@ def test_generate_concurrent() -> None: async def run_concurrent_requests(): prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"] - + async with get_session() as session: - tasks = [make_request(session, prompt, i+1) for i, prompt in enumerate(prompts)] + tasks = [ + make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts) + ] results = await asyncio.gather(*tasks, return_exceptions=True) - + # 收集成功和失败的结果 success_results = [] error_messages = [] - + for i, result in enumerate(results): if isinstance(result, Exception): error_messages.append(f"Request {i+1} failed: {str(result)}") else: - success_results.append((i+1, result)) - + success_results.append((i + 1, result)) + # 如果有任何错误,在打印完所有结果后抛出异常 if error_messages: # 先打印成功的结果 @@ -694,14 +705,14 @@ def test_generate_concurrent() -> None: if OutputControl.is_verbose(): print(f"\nRequest {req_id} succeeded:") print_json_response(result) - + # 打印所有错误信息 error_summary = "\n".join(error_messages) raise McpError( ErrorCode.InternalError, - f"Some concurrent requests failed:\n{error_summary}" + f"Some concurrent requests failed:\n{error_summary}", ) - + return results if OutputControl.is_verbose(): @@ -714,7 +725,7 @@ def test_generate_concurrent() -> None: for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) - except McpError as e: + except McpError: # 错误信息已经在之前打印过了,这里直接抛出 raise From 52f4d97172b64c13866aa2cc769564c922359024 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 04:53:05 +0800 Subject: [PATCH 09/12] Fix timing calculation logic in OllamaAPI stream generators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Initialize first_chunk_time as None • Set timing only when first chunk arrives --- lightrag/api/ollama_api.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index c6f40879..132601c3 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,14 +203,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts + first_chunk_time = last_chunk_time last_chunk_time = time.time_ns() total_response = response @@ -282,7 +283,8 @@ class OllamaAPI: } yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return - + if first_chunk_time is None: + first_chunk_time = last_chunk_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -407,14 +409,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = time.time_ns() - last_chunk_time = first_chunk_time + first_chunk_time = None + last_chunk_time = time.time_ns() total_response = "" try: # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts + first_chunk_time = last_chunk_time last_chunk_time = time.time_ns() total_response = response @@ -499,6 +502,8 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return + if first_chunk_time is None: + first_chunk_time = last_chunk_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time From f83bd765ea501092268ee6015c73ac1ae60d75eb Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 10:42:49 +0800 Subject: [PATCH 10/12] fix: improve timing accuracy and variable scoping in OllamaAPI --- lightrag/api/ollama_api.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 132601c3..edf97993 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -203,15 +203,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = last_chunk_time + first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response @@ -284,7 +284,7 @@ class OllamaAPI: yield f"{json.dumps(final_data, ensure_ascii=False)}\n" return if first_chunk_time is None: - first_chunk_time = last_chunk_time + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time @@ -409,15 +409,15 @@ class OllamaAPI: ) async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" - try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + # Ensure response is an async generator if isinstance(response, str): # If it's a string, send in two parts - first_chunk_time = last_chunk_time + first_chunk_time = start_time last_chunk_time = time.time_ns() total_response = response @@ -503,7 +503,7 @@ class OllamaAPI: return if first_chunk_time is None: - first_chunk_time = last_chunk_time + first_chunk_time = start_time completion_tokens = estimate_tokens(total_response) total_time = last_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time From e756aca3a22d96b8d400822d6687a4a5070762bb Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 23:10:15 +0800 Subject: [PATCH 11/12] Remove Chinese comments and added English comments for clarity --- test_lightrag_ollama_chat.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py index cfd7557f..80038928 100644 --- a/test_lightrag_ollama_chat.py +++ b/test_lightrag_ollama_chat.py @@ -688,7 +688,6 @@ def test_generate_concurrent() -> None: ] results = await asyncio.gather(*tasks, return_exceptions=True) - # 收集成功和失败的结果 success_results = [] error_messages = [] @@ -698,15 +697,12 @@ def test_generate_concurrent() -> None: else: success_results.append((i + 1, result)) - # 如果有任何错误,在打印完所有结果后抛出异常 if error_messages: - # 先打印成功的结果 for req_id, result in success_results: if OutputControl.is_verbose(): print(f"\nRequest {req_id} succeeded:") print_json_response(result) - # 打印所有错误信息 error_summary = "\n".join(error_messages) raise McpError( ErrorCode.InternalError, @@ -721,12 +717,12 @@ def test_generate_concurrent() -> None: # Run concurrent requests try: results = asyncio.run(run_concurrent_requests()) - # 如果没有异常,打印所有成功的结果 + # all success, print out results for i, result in enumerate(results, 1): print(f"\nRequest {i} result:") print_json_response(result) except McpError: - # 错误信息已经在之前打印过了,这里直接抛出 + # error message already printed raise From a61db0852a158a8f6bc2c9d5b23c5b5a082beeb3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 6 Feb 2025 23:12:35 +0800 Subject: [PATCH 12/12] Fix linting --- lightrag/llm/openai.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index ca451bcf..535d665c 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -78,10 +78,13 @@ from lightrag.api import __api_version__ import numpy as np from typing import Union + class InvalidResponseError(Exception): """Custom exception class for triggering retry mechanism""" + pass + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -105,10 +108,11 @@ async def openai_complete_if_cache( default_headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", - "Content-Type": "application/json" + "Content-Type": "application/json", } openai_async_client = ( - AsyncOpenAI(default_headers=default_headers) if base_url is None + AsyncOpenAI(default_headers=default_headers) + if base_url is None else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) kwargs.pop("hashing_kv", None) @@ -295,10 +299,11 @@ async def openai_embed( default_headers = { "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", - "Content-Type": "application/json" + "Content-Type": "application/json", } openai_async_client = ( - AsyncOpenAI(default_headers=default_headers) if base_url is None + AsyncOpenAI(default_headers=default_headers) + if base_url is None else AsyncOpenAI(base_url=base_url, default_headers=default_headers) ) response = await openai_async_client.embeddings.create(