Merge pull request #644 from danielaskdd/Add-Ollama-generate-API-support

Add Ollama generate API support
zrguo committed 2025-01-25 01:52:59 +08:00 (committed by GitHub)
3 changed files with 409 additions and 31 deletions


@@ -94,6 +94,7 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode q
After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.
To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model".
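As a quick sanity check outside Open WebUI, the same Ollama-compatible chat endpoint can be exercised directly. This is a minimal sketch only: the host and port (localhost:9621) are assumptions for a default local deployment, and the response is assumed to follow Ollama's chat schema.

```python
# Minimal sketch: query LightRAG through its Ollama-compatible chat endpoint.
# Host and port are assumptions; use the address your lightrag-server listens on.
import requests

payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/mix 唐僧有几个徒弟"}],
    "stream": False,
}
resp = requests.post("http://localhost:9621/api/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["message"]["content"])
```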
## Configuration


@@ -533,6 +533,7 @@ class OllamaChatRequest(BaseModel):
    messages: List[OllamaMessage]
    stream: bool = True  # Default to streaming mode
    options: Optional[Dict[str, Any]] = None
    system: Optional[str] = None


class OllamaChatResponse(BaseModel):
@@ -542,6 +543,28 @@ class OllamaChatResponse(BaseModel):
    done: bool


class OllamaGenerateRequest(BaseModel):
    model: str = LIGHTRAG_MODEL
    prompt: str
    system: Optional[str] = None
    stream: bool = False
    options: Optional[Dict[str, Any]] = None


class OllamaGenerateResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool
    context: Optional[List[int]]
    total_duration: Optional[int]
    load_duration: Optional[int]
    prompt_eval_count: Optional[int]
    prompt_eval_duration: Optional[int]
    eval_count: Optional[int]
    eval_duration: Optional[int]


class OllamaVersionResponse(BaseModel):
    version: str
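To make the new schema concrete, here is a hedged sketch of a request body for /api/generate and the shape of a non-streaming response. All values are illustrative; only model, created_at, response, and done are required by the models above, while the remaining counters and durations are optional.

```python
# Illustrative payloads only; every value below is made up for demonstration.
generate_request = {
    "model": "lightrag:latest",
    "prompt": "电视剧西游记导演是谁",
    "system": "You are a helpful assistant.",  # optional
    "stream": False,                           # optional, defaults to False
}

generate_response = {
    "model": "lightrag:latest",
    "created_at": "2025-01-25T00:00:00Z",
    "response": "(generated answer text)",
    "done": True,
    # Optional statistics (durations in nanoseconds, counts in tokens):
    "total_duration": 123456789,
    "load_duration": 0,
    "prompt_eval_count": 12,
    "prompt_eval_duration": 3456789,
    "eval_count": 256,
    "eval_duration": 120000000,
}
```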
@@ -1417,6 +1440,145 @@ def create_app(args):
return query, SearchMode.hybrid
@app.post("/api/generate")
async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests"""
try:
query = request.prompt
start_time = time.time_ns()
prompt_tokens = estimate_tokens(query)
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.llm_model_func(
query, stream=True, **rag.llm_model_kwargs
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": response,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
response_text = await rag.llm_model_func(
query, stream=False, **rag.llm_model_kwargs
)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"response": str(response_text),
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
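A hedged client-side sketch for consuming this endpoint's NDJSON stream; the host and port are assumptions and should be replaced with wherever your lightrag-server is listening.

```python
# Sketch: stream from /api/generate and print chunks as they arrive.
# Host/port are assumptions; adjust to your lightrag-server address.
import json
import requests

payload = {"model": "lightrag:latest", "prompt": "电视剧西游记导演是谁", "stream": True}
with requests.post(
    "http://localhost:9621/api/generate", json=payload, stream=True, timeout=300
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line.decode("utf-8"))
        if chunk.get("done"):
            break  # the final message carries the timing statistics
        print(chunk.get("response", ""), end="", flush=True)
print()
```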
@app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Handle chat completion requests"""
@@ -1429,16 +1591,12 @@ def create_app(args):
# Get the last message as query
query = messages[-1].content
# Parse query mode
# Check for query prefix
cleaned_query, mode = parse_query_mode(query)
# Start timing
start_time = time.time_ns()
# Count input tokens
prompt_tokens = estimate_tokens(cleaned_query)
# Run the query through RAG
query_param = QueryParam(
mode=mode, stream=request.stream, only_need_context=False
)
@@ -1549,7 +1707,21 @@ def create_app(args):
)
else:
first_chunk_time = time.time_ns()
response_text = await rag.aquery(cleaned_query, param=query_param)
# Determine whether the request comes from Open WebUI's session-title or session-keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
)
if match_result:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await rag.llm_model_func(
cleaned_query, stream=False, **rag.llm_model_kwargs
)
else:
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:
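As a hedged illustration of what that regular expression is meant to catch: the sample prompt below is an invented stand-in for an Open WebUI title-generation request (the real template may differ), and any prompt embedding a "\n<chat_history>\nUSER:" sequence is sent straight to the LLM instead of through the RAG pipeline.

```python
# Illustrative only: the sample prompt is invented, not Open WebUI's actual template.
import re

sample_task_prompt = (
    "Create a concise title for the conversation.\n"
    "<chat_history>\n"
    "USER: 唐僧有几个徒弟\n"
    "ASSISTANT: ..."
)
print(bool(re.search(r"\n<chat_history>\nUSER:", sample_task_prompt, re.MULTILINE)))  # True
```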


@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
"max_retries": 3,
"retry_delay": 1,
},
"test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
"test_cases": {
"basic": {"query": "唐僧有几个徒弟"},
"generate": {"query": "电视剧西游记导演是谁"},
},
}
@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
CONFIG = load_config()
def get_base_url() -> str:
"""Return the base URL"""
def get_base_url(endpoint: str = "chat") -> str:
"""Return the base URL for specified endpoint
Args:
endpoint: API endpoint name (chat or generate)
Returns:
Complete URL for the endpoint
"""
server = CONFIG["server"]
return f"http://{server['host']}:{server['port']}/api/chat"
return f"http://{server['host']}:{server['port']}/api/{endpoint}"
def create_request_data(
def create_chat_request_data(
content: str, stream: bool = False, model: str = None
) -> Dict[str, Any]:
"""Create basic request data
"""Create chat request data
Args:
content: User message content
stream: Whether to use streaming response
model: Model name
Returns:
Dictionary containing complete request data
Dictionary containing complete chat request data
"""
return {
"model": model or CONFIG["server"]["model"],
@@ -198,6 +206,35 @@ def create_request_data(
}
def create_generate_request_data(
    prompt: str,
    system: str = None,
    stream: bool = False,
    model: str = None,
    options: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Create generate request data
    Args:
        prompt: Generation prompt
        system: System prompt
        stream: Whether to use streaming response
        model: Model name
        options: Additional options
    Returns:
        Dictionary containing complete generate request data
    """
    data = {
        "model": model or CONFIG["server"]["model"],
        "prompt": prompt,
        "stream": stream,
    }
    if system:
        data["system"] = system
    if options:
        data["options"] = options
    return data
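A quick usage sketch of the helper above (the query string reuses one of the configured test cases); optional fields are only added when supplied, so a bare call yields the minimal payload the endpoint expects.

```python
# Illustrative usage; the resulting dict shape is shown in the comment.
data = create_generate_request_data(
    "电视剧西游记导演是谁", system="你是一个知识渊博的助手"
)
# -> {"model": <configured model>, "prompt": "...", "stream": False, "system": "..."}
```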
# Global test statistics
STATS = TestStats()
@@ -219,10 +256,12 @@ def run_test(func: Callable, name: str) -> None:
raise
def test_non_stream_chat():
def test_non_stream_chat() -> None:
"""Test non-streaming call to /api/chat endpoint"""
url = get_base_url()
data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
data = create_chat_request_data(
CONFIG["test_cases"]["basic"]["query"], stream=False
)
# Send request
response = make_request(url, data)
@@ -239,7 +278,7 @@ def test_non_stream_chat():
)
def test_stream_chat():
def test_stream_chat() -> None:
"""Test streaming call to /api/chat endpoint
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
@@ -258,7 +297,7 @@ def test_stream_chat():
The last message will contain performance statistics, with done set to true.
"""
url = get_base_url()
data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
# Send request and get streaming response
response = make_request(url, data, stream=True)
@@ -295,7 +334,7 @@ def test_stream_chat():
print()
def test_query_modes():
def test_query_modes() -> None:
"""Test different query mode prefixes
Supported query modes:
@@ -313,7 +352,7 @@ def test_query_modes():
for mode in modes:
if OutputControl.is_verbose():
print(f"\n=== Testing /{mode} mode ===")
data = create_request_data(
data = create_chat_request_data(
f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
)
@@ -354,7 +393,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
return error_data.get(error_type, error_data["empty_messages"])
def test_stream_error_handling():
def test_stream_error_handling() -> None:
"""Test error handling for streaming responses
Test scenarios:
@@ -400,7 +439,7 @@ def test_stream_error_handling():
response.close()
def test_error_handling():
def test_error_handling() -> None:
"""Test error handling for non-streaming responses
Test scenarios:
@@ -447,6 +486,165 @@ def test_error_handling():
print_json_response(response.json(), "Error message")
def test_non_stream_generate() -> None:
    """Test non-streaming call to /api/generate endpoint"""
    url = get_base_url("generate")
    data = create_generate_request_data(
        CONFIG["test_cases"]["generate"]["query"], stream=False
    )

    # Send request
    response = make_request(url, data)

    # Print response
    if OutputControl.is_verbose():
        print("\n=== Non-streaming generate response ===")
    response_json = response.json()

    # Print response content
    print_json_response(
        {
            "model": response_json["model"],
            "response": response_json["response"],
            "done": response_json["done"],
        },
        "Response content",
    )
def test_stream_generate() -> None:
    """Test streaming call to /api/generate endpoint"""
    url = get_base_url("generate")
    data = create_generate_request_data(
        CONFIG["test_cases"]["generate"]["query"], stream=True
    )

    # Send request and get streaming response
    response = make_request(url, data, stream=True)

    if OutputControl.is_verbose():
        print("\n=== Streaming generate response ===")

    output_buffer = []
    try:
        for line in response.iter_lines():
            if line:  # Skip empty lines
                try:
                    # Decode and parse JSON
                    data = json.loads(line.decode("utf-8"))
                    if data.get("done", True):  # If it's the completion marker
                        if "total_duration" in data:  # Final performance statistics message
                            break
                    else:  # Normal content message
                        content = data.get("response", "")
                        if content:  # Only collect non-empty content
                            output_buffer.append(content)
                            print(content, end="", flush=True)  # Print content in real-time
                except json.JSONDecodeError:
                    print("Error decoding JSON from response line")
    finally:
        response.close()  # Ensure the response connection is closed

    # Print a newline
    print()
def test_generate_with_system() -> None:
    """Test generate with system prompt"""
    url = get_base_url("generate")
    data = create_generate_request_data(
        CONFIG["test_cases"]["generate"]["query"],
        system="你是一个知识渊博的助手",
        stream=False,
    )

    # Send request
    response = make_request(url, data)

    # Print response
    if OutputControl.is_verbose():
        print("\n=== Generate with system prompt response ===")
    response_json = response.json()

    # Print response content
    print_json_response(
        {
            "model": response_json["model"],
            "response": response_json["response"],
            "done": response_json["done"],
        },
        "Response content",
    )
def test_generate_error_handling() -> None:
    """Test error handling for generate endpoint"""
    url = get_base_url("generate")

    # Test empty prompt
    if OutputControl.is_verbose():
        print("\n=== Testing empty prompt ===")
    data = create_generate_request_data("", stream=False)
    response = make_request(url, data)
    print(f"Status code: {response.status_code}")
    print_json_response(response.json(), "Error message")

    # Test invalid options
    if OutputControl.is_verbose():
        print("\n=== Testing invalid options ===")
    data = create_generate_request_data(
        CONFIG["test_cases"]["basic"]["query"],
        options={"invalid_option": "value"},
        stream=False,
    )
    response = make_request(url, data)
    print(f"Status code: {response.status_code}")
    print_json_response(response.json(), "Error message")
def test_generate_concurrent() -> None:
    """Test concurrent generate requests"""
    import asyncio
    import aiohttp
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def get_session():
        async with aiohttp.ClientSession() as session:
            yield session

    async def make_request(session, prompt: str):
        url = get_base_url("generate")
        data = create_generate_request_data(prompt, stream=False)
        try:
            async with session.post(url, json=data) as response:
                return await response.json()
        except Exception as e:
            return {"error": str(e)}

    async def run_concurrent_requests():
        prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
        async with get_session() as session:
            tasks = [make_request(session, prompt) for prompt in prompts]
            results = await asyncio.gather(*tasks)
            return results

    if OutputControl.is_verbose():
        print("\n=== Testing concurrent generate requests ===")

    # Run concurrent requests
    results = asyncio.run(run_concurrent_requests())

    # Print results
    for i, result in enumerate(results, 1):
        print(f"\nRequest {i} result:")
        print_json_response(result)
def get_test_cases() -> Dict[str, Callable]:
"""Get all available test cases
Returns:
@@ -458,6 +656,11 @@ def get_test_cases() -> Dict[str, Callable]:
"modes": test_query_modes,
"errors": test_error_handling,
"stream_errors": test_stream_error_handling,
"non_stream_generate": test_non_stream_generate,
"stream_generate": test_stream_generate,
"generate_with_system": test_generate_with_system,
"generate_errors": test_generate_error_handling,
"generate_concurrent": test_generate_concurrent,
}
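For a quick spot-check outside the full suite, the new cases can also be driven individually through the existing run_test helper. This is a sketch that assumes it runs inside the test module's namespace, where run_test and the generate test functions are already in scope.

```python
# Sketch: run only the new generate tests via the existing run_test helper.
for name, func in [
    ("non_stream_generate", test_non_stream_generate),
    ("stream_generate", test_stream_generate),
    ("generate_with_system", test_generate_with_system),
]:
    run_test(func, name)
```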
@@ -544,18 +747,20 @@ if __name__ == "__main__":
if "all" in args.tests:
# Run all tests
if OutputControl.is_verbose():
print("\nBasic Functionality Tests】")
run_test(test_non_stream_chat, "Non-streaming Call Test")
run_test(test_stream_chat, "Streaming Call Test")
print("\nChat API Tests】")
run_test(test_non_stream_chat, "Non-streaming Chat Test")
run_test(test_stream_chat, "Streaming Chat Test")
run_test(test_query_modes, "Chat Query Mode Test")
run_test(test_error_handling, "Chat Error Handling Test")
run_test(test_stream_error_handling, "Chat Streaming Error Test")
if OutputControl.is_verbose():
print("\nQuery Mode Tests】")
run_test(test_query_modes, "Query Mode Test")
if OutputControl.is_verbose():
print("\nError Handling Tests】")
run_test(test_error_handling, "Error Handling Test")
run_test(test_stream_error_handling, "Streaming Error Handling Test")
print("\nGenerate API Tests】")
run_test(test_non_stream_generate, "Non-streaming Generate Test")
run_test(test_stream_generate, "Streaming Generate Test")
run_test(test_generate_with_system, "Generate with System Prompt Test")
run_test(test_generate_error_handling, "Generate Error Handling Test")
run_test(test_generate_concurrent, "Generate Concurrent Test")
else:
# Run specified tests
for test_name in args.tests: