Merge pull request #723 from danielaskdd/improve-ollama-api-streaming

Improve error handling
Merged by zrguo on 2025-02-07 02:18:54 +08:00 (committed via GitHub)
4 changed files with 223 additions and 102 deletions

View File

@@ -205,14 +205,14 @@ class OllamaAPI:
         async def stream_generator():
             try:
                 first_chunk_time = None
-                last_chunk_time = None
+                last_chunk_time = time.time_ns()
                 total_response = ""

                 # Ensure response is an async generator
                 if isinstance(response, str):
                     # If it's a string, send in two parts
-                    first_chunk_time = time.time_ns()
-                    last_chunk_time = first_chunk_time
+                    first_chunk_time = start_time
+                    last_chunk_time = time.time_ns()
                     total_response = response

                     data = {
@@ -241,22 +241,50 @@ class OllamaAPI:
                     }
                     yield f"{json.dumps(data, ensure_ascii=False)}\n"
                 else:
-                    async for chunk in response:
-                        if chunk:
-                            if first_chunk_time is None:
-                                first_chunk_time = time.time_ns()
-                            last_chunk_time = time.time_ns()
-                            total_response += chunk
-                            data = {
-                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                                "response": chunk,
-                                "done": False,
-                            }
-                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                    try:
+                        async for chunk in response:
+                            if chunk:
+                                if first_chunk_time is None:
+                                    first_chunk_time = time.time_ns()
+                                last_chunk_time = time.time_ns()
+                                total_response += chunk
+                                data = {
+                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                                    "response": chunk,
+                                    "done": False,
+                                }
+                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                    except (asyncio.CancelledError, Exception) as e:
+                        error_msg = str(e)
+                        if isinstance(e, asyncio.CancelledError):
+                            error_msg = "Stream was cancelled by server"
+                        else:
+                            error_msg = f"Provider error: {error_msg}"
+                        logging.error(f"Stream error: {error_msg}")
+
+                        # Send error message to client
+                        error_data = {
+                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                            "response": f"\n\nError: {error_msg}",
+                            "done": False,
+                        }
+                        yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
+
+                        # Send final message to close the stream
+                        final_data = {
+                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                            "done": True,
+                        }
+                        yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
+                        return
+
+                if first_chunk_time is None:
+                    first_chunk_time = start_time
+
                 completion_tokens = estimate_tokens(total_response)
                 total_time = last_chunk_time - start_time
                 prompt_eval_time = first_chunk_time - start_time
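Note on the timing fields used above: start_time, first_chunk_time, and last_chunk_time are nanosecond timestamps from time.time_ns(), and the new fallback to start_time keeps the subtractions well defined even when no chunk ever arrives. A minimal, self-contained sketch (not part of this commit) of how those timestamps become the Ollama-style duration fields:

import time

def timing_fields(start_time: int, first_chunk_time: int, last_chunk_time: int) -> dict:
    """Compute Ollama-style duration fields (all values in nanoseconds)."""
    return {
        "total_duration": last_chunk_time - start_time,         # request start -> last chunk
        "prompt_eval_duration": first_chunk_time - start_time,  # time to first chunk
        "eval_duration": last_chunk_time - first_chunk_time,    # first chunk -> last chunk
    }

# Example: fall back to start_time when no chunk ever arrived,
# mirroring the `if first_chunk_time is None` guard above.
start = time.time_ns()
first = last = time.time_ns()
print(timing_fields(start, first, last))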
@@ -381,16 +409,16 @@ class OllamaAPI:
             )

         async def stream_generator():
-            first_chunk_time = None
-            last_chunk_time = None
-            total_response = ""
-
             try:
+                first_chunk_time = None
+                last_chunk_time = time.time_ns()
+                total_response = ""
+
                 # Ensure response is an async generator
                 if isinstance(response, str):
                     # If it's a string, send in two parts
-                    first_chunk_time = time.time_ns()
-                    last_chunk_time = first_chunk_time
+                    first_chunk_time = start_time
+                    last_chunk_time = time.time_ns()
                     total_response = response

                     data = {
@@ -474,45 +502,29 @@ class OllamaAPI:
                         yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                         return

-                if last_chunk_time is not None:
-                    completion_tokens = estimate_tokens(total_response)
-                    total_time = last_chunk_time - start_time
-                    prompt_eval_time = first_chunk_time - start_time
-                    eval_time = last_chunk_time - first_chunk_time
-
-                    data = {
-                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                        "done": True,
-                        "total_duration": total_time,
-                        "load_duration": 0,
-                        "prompt_eval_count": prompt_tokens,
-                        "prompt_eval_duration": prompt_eval_time,
-                        "eval_count": completion_tokens,
-                        "eval_duration": eval_time,
-                    }
-                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                if first_chunk_time is None:
+                    first_chunk_time = start_time
+
+                completion_tokens = estimate_tokens(total_response)
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                data = {
+                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time,
+                }
+                yield f"{json.dumps(data, ensure_ascii=False)}\n"
             except Exception as e:
-                error_msg = f"Error in stream_generator: {str(e)}"
-                logging.error(error_msg)
-                # Send error message to client
-                error_data = {
-                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                    "error": {"code": "STREAM_ERROR", "message": error_msg},
-                }
-                yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
-                # Ensure sending end marker
-                final_data = {
-                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                    "done": True,
-                }
-                yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
-                return
+                trace_exception(e)
+                raise

         return StreamingResponse(
             stream_generator(),
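The two stream_generator changes above follow the same pattern: wrap the upstream async iterator in try/except, and on failure emit one NDJSON error chunk followed by a final "done": true chunk so the client can terminate cleanly instead of hanging. A self-contained sketch of that pattern, not taken from this repository (upstream() and the model name are placeholders):

import asyncio
import json
import logging

async def wrap_stream(upstream, model: str = "lightrag:latest"):
    """Yield NDJSON lines; on error, emit an error chunk and then a closing done-chunk."""
    try:
        async for chunk in upstream:
            yield json.dumps({"model": model, "response": chunk, "done": False}) + "\n"
    except (asyncio.CancelledError, Exception) as e:
        if isinstance(e, asyncio.CancelledError):
            msg = "Stream was cancelled by server"
        else:
            msg = f"Provider error: {e}"
        logging.error(f"Stream error: {msg}")
        # Tell the client what went wrong, then close the stream properly.
        yield json.dumps({"model": model, "response": f"\n\nError: {msg}", "done": False}) + "\n"
        yield json.dumps({"model": model, "done": True}) + "\n"
        return
    yield json.dumps({"model": model, "done": True}) + "\n"

async def demo():
    async def upstream():
        yield "Hello, "
        raise RuntimeError("provider dropped the connection")

    async for line in wrap_stream(upstream()):
        print(line, end="")

asyncio.run(demo())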

View File

@@ -66,6 +66,7 @@ from lightrag.exceptions import (
     RateLimitError,
     APITimeoutError,
 )
+from lightrag.api import __api_version__

 import numpy as np
 from typing import Union
@@ -91,11 +92,12 @@ async def ollama_model_if_cache(
     timeout = kwargs.pop("timeout", None)
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
-    headers = (
-        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-        if api_key
-        else {"Content-Type": "application/json"}
-    )
+    headers = {
+        "Content-Type": "application/json",
+        "User-Agent": f"LightRAG/{__api_version__}",
+    }
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
     messages = []
     if system_prompt:
@@ -147,11 +149,12 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
-    headers = (
-        {"Content-Type": "application/json", "Authorization": api_key}
-        if api_key
-        else {"Content-Type": "application/json"}
-    )
+    headers = {
+        "Content-Type": "application/json",
+        "User-Agent": f"LightRAG/{__api_version__}",
+    }
+    if api_key:
+        headers["Authorization"] = api_key
     kwargs["headers"] = headers
     ollama_client = ollama.Client(**kwargs)
     data = ollama_client.embed(model=embed_model, input=texts)
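Both hunks in this file replace the conditional-expression header with an explicit dict, so the User-Agent is always sent and Authorization is only added when an API key exists (a Bearer token for the chat client, the raw key for the embedding client, matching the previous behaviour). A small sketch of that pattern, with a placeholder version string instead of the real lightrag.api.__api_version__:

from typing import Optional

__api_version__ = "0.0.0"  # placeholder, not the real lightrag.api.__api_version__

def build_headers(api_key: Optional[str] = None, bearer: bool = True) -> dict:
    headers = {
        "Content-Type": "application/json",
        "User-Agent": f"LightRAG/{__api_version__}",
    }
    if api_key:
        # The chat client sends "Bearer <key>"; the embedding client passes the key as-is.
        headers["Authorization"] = f"Bearer {api_key}" if bearer else api_key
    return headers

print(build_headers())           # User-Agent only, no Authorization header
print(build_headers("sk-test"))  # adds a Bearer Authorization header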

View File

@@ -73,16 +73,23 @@ from lightrag.utils import (
     logger,
 )
 from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__

 import numpy as np
 from typing import Union


+class InvalidResponseError(Exception):
+    """Custom exception class for triggering retry mechanism"""
+
+    pass
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APITimeoutError)
+        (RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError)
     ),
 )
 async def openai_complete_if_cache(
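Adding InvalidResponseError to retry_if_exception_type means tenacity now re-issues the request when the provider returns a syntactically valid but empty or malformed completion, not only on connection, rate-limit, and timeout errors. A runnable sketch of that behaviour (the flaky function is hypothetical, and the wait bounds are shortened from the 4-10 s used above so the demo finishes quickly):

from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

class InvalidResponseError(Exception):
    """Raised for empty or malformed completions so tenacity retries them."""

attempts = 0

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=4),
    retry=retry_if_exception_type(InvalidResponseError),
)
def flaky_completion() -> str:
    # Fails twice with an "empty content" error, then succeeds on the third attempt.
    global attempts
    attempts += 1
    if attempts < 3:
        raise InvalidResponseError("Received empty content")
    return "ok"

print(flaky_completion(), "after", attempts, "attempts")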
@@ -99,8 +106,14 @@ async def openai_complete_if_cache(
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key

+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
     openai_async_client = (
-        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+        AsyncOpenAI(default_headers=default_headers)
+        if base_url is None
+        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
@@ -112,17 +125,35 @@ async def openai_complete_if_cache(
     # Add logging output
     logger.debug("===== Query Input to LLM =====")
+    logger.debug(f"Model: {model} Base URL: {base_url}")
+    logger.debug(f"Additional kwargs: {kwargs}")
     logger.debug(f"Query: {prompt}")
     logger.debug(f"System prompt: {system_prompt}")
-    logger.debug("Full context:")
-    if "response_format" in kwargs:
-        response = await openai_async_client.beta.chat.completions.parse(
-            model=model, messages=messages, **kwargs
-        )
-    else:
-        response = await openai_async_client.chat.completions.create(
-            model=model, messages=messages, **kwargs
-        )
+    # logger.debug(f"Messages: {messages}")
+
+    try:
+        if "response_format" in kwargs:
+            response = await openai_async_client.beta.chat.completions.parse(
+                model=model, messages=messages, **kwargs
+            )
+        else:
+            response = await openai_async_client.chat.completions.create(
+                model=model, messages=messages, **kwargs
+            )
+    except APIConnectionError as e:
+        logger.error(f"OpenAI API Connection Error: {str(e)}")
+        raise
+    except RateLimitError as e:
+        logger.error(f"OpenAI API Rate Limit Error: {str(e)}")
+        raise
+    except APITimeoutError as e:
+        logger.error(f"OpenAI API Timeout Error: {str(e)}")
+        raise
+    except Exception as e:
+        logger.error(f"OpenAI API Call Failed: {str(e)}")
+        logger.error(f"Model: {model}")
+        logger.error(f"Request parameters: {kwargs}")
+        raise

     if hasattr(response, "__aiter__"):
@@ -140,8 +171,23 @@ async def openai_complete_if_cache(
                     raise

         return inner()
     else:
+        if (
+            not response
+            or not response.choices
+            or not hasattr(response.choices[0], "message")
+            or not hasattr(response.choices[0].message, "content")
+        ):
+            logger.error("Invalid response from OpenAI API")
+            raise InvalidResponseError("Invalid response from OpenAI API")
+
         content = response.choices[0].message.content
+
+        if not content or content.strip() == "":
+            logger.error("Received empty content from OpenAI API")
+            raise InvalidResponseError("Received empty content from OpenAI API")
+
         if r"\u" in content:
             content = safe_unicode_decode(content.encode("utf-8"))
+
         return content
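The new guards above are what raise InvalidResponseError for the retry decorator: first the attribute chain response.choices[0].message.content is verified, then empty or whitespace-only content is rejected. A standalone sketch of the same checks, using SimpleNamespace stand-ins instead of real OpenAI response objects:

from types import SimpleNamespace

class InvalidResponseError(Exception):
    """Signals a malformed or empty completion (see the retry decorator above)."""

def extract_content(response) -> str:
    """Guard the attribute chain, then reject empty text."""
    if (
        not response
        or not getattr(response, "choices", None)
        or not hasattr(response.choices[0], "message")
        or not hasattr(response.choices[0].message, "content")
    ):
        raise InvalidResponseError("Invalid response from OpenAI API")
    content = response.choices[0].message.content
    if not content or content.strip() == "":
        raise InvalidResponseError("Received empty content from OpenAI API")
    return content

good = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="hello"))])
empty = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="  "))])
print(extract_content(good))
try:
    extract_content(empty)
except InvalidResponseError as e:
    print("rejected:", e)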
@@ -251,8 +297,14 @@ async def openai_embed(
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key

+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
     openai_async_client = (
-        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+        AsyncOpenAI(default_headers=default_headers)
+        if base_url is None
+        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
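Both OpenAI client constructors now pass default_headers, which the OpenAI Python SDK attaches to every outgoing request. A sketch of the construction only (no request is made; the API key, base URL handling, and version string are placeholders):

from typing import Optional
from openai import AsyncOpenAI

__api_version__ = "0.0.0"  # placeholder for lightrag.api.__api_version__

default_headers = {
    "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
    "Content-Type": "application/json",
}

def make_client(base_url: Optional[str] = None) -> AsyncOpenAI:
    # default_headers is merged into the headers of every request this client sends.
    if base_url is None:
        return AsyncOpenAI(api_key="sk-placeholder", default_headers=default_headers)
    return AsyncOpenAI(api_key="sk-placeholder", base_url=base_url, default_headers=default_headers)

client = make_client()
print("User-Agent:", default_headers["User-Agent"])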

View File

@@ -17,14 +17,32 @@ from typing import Dict, Any, Optional, List, Callable
 from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
+from enum import Enum, auto
+
+
+class ErrorCode(Enum):
+    """Error codes for MCP errors"""
+
+    InvalidRequest = auto()
+    InternalError = auto()
+
+
+class McpError(Exception):
+    """Base exception class for MCP errors"""
+
+    def __init__(self, code: ErrorCode, message: str):
+        self.code = code
+        self.message = message
+        super().__init__(message)
+

 DEFAULT_CONFIG = {
     "server": {
         "host": "localhost",
         "port": 9621,
         "model": "lightrag:latest",
-        "timeout": 120,
-        "max_retries": 3,
+        "timeout": 300,
+        "max_retries": 1,
         "retry_delay": 1,
     },
     "test_cases": {
@@ -527,14 +545,7 @@ def test_non_stream_generate() -> None:
     response_json = response.json()

     # Print response content
-    print_json_response(
-        {
-            "model": response_json["model"],
-            "response": response_json["response"],
-            "done": response_json["done"],
-        },
-        "Response content",
-    )
+    print(json.dumps(response_json, ensure_ascii=False, indent=2))


 def test_stream_generate() -> None:
@@ -641,35 +652,78 @@ def test_generate_concurrent() -> None:
         async with aiohttp.ClientSession() as session:
             yield session

-    async def make_request(session, prompt: str):
+    async def make_request(session, prompt: str, request_id: int):
         url = get_base_url("generate")
         data = create_generate_request_data(prompt, stream=False)
         try:
             async with session.post(url, json=data) as response:
                 if response.status != 200:
-                    response.raise_for_status()
-                return await response.json()
+                    error_msg = (
+                        f"Request {request_id} failed with status {response.status}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                result = await response.json()
+                if "error" in result:
+                    error_msg = (
+                        f"Request {request_id} returned error: {result['error']}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                return result
         except Exception as e:
-            return {"error": str(e)}
+            error_msg = f"Request {request_id} failed: {str(e)}"
+            if OutputControl.is_verbose():
+                print(f"\n{error_msg}")
+            raise McpError(ErrorCode.InternalError, error_msg)

     async def run_concurrent_requests():
         prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
         async with get_session() as session:
-            tasks = [make_request(session, prompt) for prompt in prompts]
-            results = await asyncio.gather(*tasks)
+            tasks = [
+                make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)
+            ]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            success_results = []
+            error_messages = []
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    error_messages.append(f"Request {i+1} failed: {str(result)}")
+                else:
+                    success_results.append((i + 1, result))
+
+            if error_messages:
+                for req_id, result in success_results:
+                    if OutputControl.is_verbose():
+                        print(f"\nRequest {req_id} succeeded:")
+                    print_json_response(result)
+                error_summary = "\n".join(error_messages)
+                raise McpError(
+                    ErrorCode.InternalError,
+                    f"Some concurrent requests failed:\n{error_summary}",
+                )
+
             return results

     if OutputControl.is_verbose():
         print("\n=== Testing concurrent generate requests ===")

     # Run concurrent requests
-    results = asyncio.run(run_concurrent_requests())
-
-    # Print results
-    for i, result in enumerate(results, 1):
-        print(f"\nRequest {i} result:")
-        print_json_response(result)
+    try:
+        results = asyncio.run(run_concurrent_requests())
+        # all success, print out results
+        for i, result in enumerate(results, 1):
+            print(f"\nRequest {i} result:")
+            print_json_response(result)
+    except McpError:
+        # error message already printed
+        raise


 def get_test_cases() -> Dict[str, Callable]:
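The reworked concurrent test relies on asyncio.gather(..., return_exceptions=True): failed requests come back as exception objects in the results list instead of cancelling their siblings, so the test can print the successful responses and then raise one aggregated error. A compact sketch of that aggregation with fake requests (the function names and the simulated failure are hypothetical):

import asyncio

async def fake_request(request_id: int) -> dict:
    if request_id == 2:
        raise RuntimeError(f"Request {request_id} failed with status 500")
    return {"request": request_id, "response": "ok"}

async def run_concurrent() -> list:
    tasks = [fake_request(i + 1) for i in range(3)]
    # return_exceptions=True keeps successful results even when some tasks fail.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    errors = [f"Request {i+1}: {r}" for i, r in enumerate(results) if isinstance(r, Exception)]
    if errors:
        raise RuntimeError("Some concurrent requests failed:\n" + "\n".join(errors))
    return results

try:
    asyncio.run(run_concurrent())
except RuntimeError as e:
    print(e)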