Merge tag 'time-temp' into improve-ollama-api-streaming

yangdx
2025-02-06 23:00:32 +08:00
2 changed files with 147 additions and 70 deletions

View File

@@ -205,14 +205,14 @@ class OllamaAPI:
         async def stream_generator():
             try:
                 first_chunk_time = None
-                last_chunk_time = None
+                last_chunk_time = time.time_ns()
                 total_response = ""

                 # Ensure response is an async generator
                 if isinstance(response, str):
                     # If it's a string, send in two parts
-                    first_chunk_time = time.time_ns()
-                    last_chunk_time = first_chunk_time
+                    first_chunk_time = start_time
+                    last_chunk_time = time.time_ns()
                     total_response = response

                     data = {
@@ -241,6 +241,7 @@ class OllamaAPI:
                     }
                     yield f"{json.dumps(data, ensure_ascii=False)}\n"
                 else:
+                    try:
                         async for chunk in response:
                             if chunk:
                                 if first_chunk_time is None:
@@ -256,7 +257,34 @@ class OllamaAPI:
"done": False, "done": False,
} }
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logging.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": f"\n\nError: {error_msg}",
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response) completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time
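
The hunk above establishes an in-band error protocol: when the provider stream dies, the error is sent as an ordinary response chunk followed by a closing done=True frame, so NDJSON clients never hang waiting for more lines. A minimal standalone sketch of that pattern, assuming a generic async chunk source (failing_source and the bare field names here are illustrative, not the project's API):

    import asyncio
    import json


    async def failing_source():
        """Illustrative provider stream that dies midway."""
        yield "Hello, "
        yield "world"
        raise RuntimeError("provider dropped the connection")


    async def ndjson_stream(chunks):
        """Relay chunks as NDJSON; report failures in-band, then always close the stream."""
        try:
            async for chunk in chunks:
                yield json.dumps({"response": chunk, "done": False}, ensure_ascii=False) + "\n"
        except Exception as e:
            # The 200 status line is already on the wire, so the error must travel as payload.
            yield json.dumps({"response": f"\n\nError: Provider error: {e}", "done": False}) + "\n"
        # The closing frame tells clients to stop reading instead of waiting for EOF.
        yield json.dumps({"done": True}) + "\n"


    async def demo():
        async for line in ndjson_stream(failing_source()):
            print(line, end="")


    asyncio.run(demo())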
@@ -381,16 +409,16 @@ class OllamaAPI:
         )

         async def stream_generator():
-            try:
-                first_chunk_time = None
-                last_chunk_time = None
-                total_response = ""
+            first_chunk_time = None
+            last_chunk_time = time.time_ns()
+            total_response = ""
+            try:
                 # Ensure response is an async generator
                 if isinstance(response, str):
                     # If it's a string, send in two parts
-                    first_chunk_time = time.time_ns()
-                    last_chunk_time = first_chunk_time
+                    first_chunk_time = start_time
+                    last_chunk_time = time.time_ns()
                     total_response = response

                     data = {
@@ -474,7 +502,8 @@ class OllamaAPI:
yield f"{json.dumps(final_data, ensure_ascii=False)}\n" yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return return
if last_chunk_time is not None: if first_chunk_time is None:
first_chunk_time = start_time
completion_tokens = estimate_tokens(total_response) completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time prompt_eval_time = first_chunk_time - start_time
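
This fallback keeps the nanosecond arithmetic safe: if the stream produced no chunks, first_chunk_time would still be None and the subtraction would raise a TypeError. The duration bookkeeping, isolated into a sketch (the field names follow Ollama's generate statistics; timing_stats is a hypothetical helper, not code from this repo):

    import time


    def timing_stats(start_time: int, first_chunk_time, last_chunk_time: int) -> dict:
        """Durations in nanoseconds, shaped like Ollama's generate statistics."""
        if first_chunk_time is None:
            # No chunk ever arrived; mirror the hunk's fallback to avoid None - int.
            first_chunk_time = start_time
        return {
            "total_duration": last_chunk_time - start_time,
            "prompt_eval_duration": first_chunk_time - start_time,
            "eval_duration": last_chunk_time - first_chunk_time,
        }


    start = time.time_ns()
    stats = timing_stats(start, None, time.time_ns())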
@@ -494,25 +523,8 @@ class OllamaAPI:
yield f"{json.dumps(data, ensure_ascii=False)}\n" yield f"{json.dumps(data, ensure_ascii=False)}\n"
except Exception as e: except Exception as e:
error_msg = f"Error in stream_generator: {str(e)}" trace_exception(e)
logging.error(error_msg) raise
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"error": {"code": "STREAM_ERROR", "message": error_msg},
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Ensure sending end marker
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
return StreamingResponse( return StreamingResponse(
stream_generator(), stream_generator(),
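
Note the asymmetry with the first endpoint: here the handler logs via trace_exception and re-raises instead of yielding error frames. Once StreamingResponse has begun sending, the 200 status is already on the wire, so a raised exception simply terminates the connection. A minimal sketch of the wiring, assuming FastAPI (the route path and media type are assumptions, not confirmed by this diff):

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()


    async def stream_generator():
        # Each line is one NDJSON frame; an exception raised here arrives
        # after the 200 response line has already been sent.
        yield '{"response": "partial answer", "done": false}\n'
        yield '{"done": true}\n'


    @app.post("/api/generate")  # illustrative route
    async def generate():
        return StreamingResponse(stream_generator(), media_type="application/x-ndjson")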

View File

@@ -17,6 +17,24 @@ from typing import Dict, Any, Optional, List, Callable
 from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
+from enum import Enum, auto
+
+
+class ErrorCode(Enum):
+    """Error codes for MCP errors"""
+
+    InvalidRequest = auto()
+    InternalError = auto()
+
+
+class McpError(Exception):
+    """Base exception class for MCP errors"""
+
+    def __init__(self, code: ErrorCode, message: str):
+        self.code = code
+        self.message = message
+        super().__init__(message)
+

 DEFAULT_CONFIG = {
     "server": {
@@ -634,35 +652,82 @@ def test_generate_concurrent() -> None:
         async with aiohttp.ClientSession() as session:
             yield session

-    async def make_request(session, prompt: str):
+    async def make_request(session, prompt: str, request_id: int):
         url = get_base_url("generate")
         data = create_generate_request_data(prompt, stream=False)
         try:
             async with session.post(url, json=data) as response:
                 if response.status != 200:
-                    response.raise_for_status()
-                return await response.json()
+                    error_msg = (
+                        f"Request {request_id} failed with status {response.status}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                result = await response.json()
+                if "error" in result:
+                    error_msg = (
+                        f"Request {request_id} returned error: {result['error']}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                return result
         except Exception as e:
-            return {"error": str(e)}
+            error_msg = f"Request {request_id} failed: {str(e)}"
+            if OutputControl.is_verbose():
+                print(f"\n{error_msg}")
+            raise McpError(ErrorCode.InternalError, error_msg)

     async def run_concurrent_requests():
         prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
         async with get_session() as session:
-            tasks = [make_request(session, prompt) for prompt in prompts]
-            results = await asyncio.gather(*tasks)
+            tasks = [
+                make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)
+            ]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Collect successful and failed results
+            success_results = []
+            error_messages = []
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    error_messages.append(f"Request {i+1} failed: {str(result)}")
+                else:
+                    success_results.append((i + 1, result))
+
+            # If anything failed, raise after printing all results
+            if error_messages:
+                # Print the successful results first
+                for req_id, result in success_results:
+                    if OutputControl.is_verbose():
+                        print(f"\nRequest {req_id} succeeded:")
+                    print_json_response(result)
+                # Print all error messages
+                error_summary = "\n".join(error_messages)
+                raise McpError(
+                    ErrorCode.InternalError,
+                    f"Some concurrent requests failed:\n{error_summary}",
+                )
+
             return results

     if OutputControl.is_verbose():
         print("\n=== Testing concurrent generate requests ===")

     # Run concurrent requests
-    results = asyncio.run(run_concurrent_requests())
-
-    # Print results
-    for i, result in enumerate(results, 1):
-        print(f"\nRequest {i} result:")
-        print_json_response(result)
+    try:
+        results = asyncio.run(run_concurrent_requests())
+        # If there were no exceptions, print all successful results
+        for i, result in enumerate(results, 1):
+            print(f"\nRequest {i} result:")
+            print_json_response(result)
+    except McpError:
+        # Error details were already printed above; just re-raise
+        raise


 def get_test_cases() -> Dict[str, Callable]:
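
The load-bearing change in this test driver is return_exceptions=True: asyncio.gather no longer propagates the first failure immediately, but returns exceptions in-place so every request's outcome can be reported in order. Stripped of the test harness, the pattern is (all names illustrative):

    import asyncio


    async def work(i: int) -> str:
        """Illustrative task; one of them fails."""
        if i == 2:
            raise ValueError(f"task {i} failed")
        return f"task {i} ok"


    async def main() -> None:
        results = await asyncio.gather(
            *(work(i) for i in range(1, 6)), return_exceptions=True
        )
        # Exceptions come back as ordinary list items, in request order.
        for i, result in enumerate(results, 1):
            if isinstance(result, Exception):
                print(f"request {i} failed: {result}")
            else:
                print(f"request {i} succeeded: {result}")


    asyncio.run(main())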