Split the Ollama API implementation into a separate file
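The refactor moves the Ollama-compatible endpoints out of the FastAPI server module and into the new lightrag/api/ollama_api.py, which wraps them in an OllamaAPI class exposing an APIRouter. A minimal sketch of how the extracted router is mounted, based on the imports and the include_router call visible in the diff below (the LightRAG construction is elided and only stands in for whatever create_app builds):

# Sketch only: mirrors the wiring shown in the diff below; `rag` is assumed to be
# an already-configured LightRAG instance (its construction is elided here).
from fastapi import FastAPI
from lightrag.api.ollama_api import OllamaAPI

app = FastAPI()
rag = ...  # LightRAG(...), configured as in create_app()

ollama_api = OllamaAPI(rag)
# Serves /api/version, /api/tags, /api/generate and /api/chat, like the inlined endpoints it replaces
app.include_router(ollama_api.router, prefix="/api")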
@@ -4,28 +4,20 @@ from fastapi import (
    File,
    UploadFile,
    Form,
    Request,
    BackgroundTasks,
)

# Backend (Python)
# Add this to store progress globally
from typing import Dict
import threading
import asyncio
import json
import os

import json
import re
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import logging
import argparse
import time
import re
from typing import List, Any, Optional, Union
from typing import List, Any, Optional, Union, Dict
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.api import __api_version__

from lightrag.utils import EmbeddingFunc
from enum import Enum
from pathlib import Path
@@ -34,20 +26,30 @@ import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import sys
import configparser

from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager

from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm

from dotenv import load_dotenv
from .ollama_api import (
    OllamaAPI,
)
from .ollama_api import ollama_server_infos

# Load environment variables
load_dotenv()

class RAGStorageConfig:
    KV_STORAGE = "JsonKVStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    GRAPH_STORAGE = "NetworkXStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"

# Initialize rag storage config
rag_storage_config = RAGStorageConfig()

# Global progress tracker
scan_progress: Dict = {
    "is_scanning": False,
@@ -76,24 +78,6 @@ def estimate_tokens(text: str) -> int:
    return int(tokens)


class OllamaServerInfos:
    # Constants for emulated Ollama model information
    LIGHTRAG_NAME = "lightrag"
    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
    LIGHTRAG_SIZE = 7365960935  # it's a dummy value
    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
    LIGHTRAG_DIGEST = "sha256:lightrag"

    KV_STORAGE = "JsonKVStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    GRAPH_STORAGE = "NetworkXStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"


# Add infos
ollama_server_infos = OllamaServerInfos()

# read config.ini
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -101,8 +85,8 @@ config.read("config.ini", "utf-8")
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
    os.environ["REDIS_URI"] = redis_uri
    ollama_server_infos.KV_STORAGE = "RedisKVStorage"
    ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage"
    rag_storage_config.KV_STORAGE = "RedisKVStorage"
    rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"

# Neo4j config
neo4j_uri = config.get("neo4j", "uri", fallback=None)
@@ -112,7 +96,7 @@ if neo4j_uri:
    os.environ["NEO4J_URI"] = neo4j_uri
    os.environ["NEO4J_USERNAME"] = neo4j_username
    os.environ["NEO4J_PASSWORD"] = neo4j_password
    ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage"
    rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"

# Milvus config
milvus_uri = config.get("milvus", "uri", fallback=None)
@@ -124,7 +108,7 @@ if milvus_uri:
    os.environ["MILVUS_USER"] = milvus_user
    os.environ["MILVUS_PASSWORD"] = milvus_password
    os.environ["MILVUS_DB_NAME"] = milvus_db_name
    ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge"
    rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"

# MongoDB config
mongo_uri = config.get("mongodb", "uri", fallback=None)
@@ -132,8 +116,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
if mongo_uri:
    os.environ["MONGO_URI"] = mongo_uri
    os.environ["MONGO_DATABASE"] = mongo_database
    ollama_server_infos.KV_STORAGE = "MongoKVStorage"
    ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage"
    rag_storage_config.KV_STORAGE = "MongoKVStorage"
    rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"


def get_default_host(binding_type: str) -> str:
@@ -535,6 +519,7 @@ def parse_args() -> argparse.Namespace:
        help="Cosine similarity threshold (default: from env or 0.4)",
    )

    # Ollama model name
    parser.add_argument(
        "--simulated-model-name",
        type=str,
@@ -599,84 +584,13 @@ class DocumentManager:
        return any(filename.lower().endswith(ext) for ext in self.supported_extensions)


# Pydantic models
# LightRAG query mode
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"
    hybrid = "hybrid"
    mix = "mix"
    bypass = "bypass"


class OllamaMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class OllamaChatRequest(BaseModel):
    model: str = ollama_server_infos.LIGHTRAG_MODEL
    messages: List[OllamaMessage]
    stream: bool = True  # Default to streaming mode
    options: Optional[Dict[str, Any]] = None
    system: Optional[str] = None


class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool


class OllamaGenerateRequest(BaseModel):
    model: str = ollama_server_infos.LIGHTRAG_MODEL
    prompt: str
    system: Optional[str] = None
    stream: bool = False
    options: Optional[Dict[str, Any]] = None


class OllamaGenerateResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool
    context: Optional[List[int]]
    total_duration: Optional[int]
    load_duration: Optional[int]
    prompt_eval_count: Optional[int]
    prompt_eval_duration: Optional[int]
    eval_count: Optional[int]
    eval_duration: Optional[int]


class OllamaVersionResponse(BaseModel):
    version: str


class OllamaModelDetails(BaseModel):
    parent_model: str
    format: str
    family: str
    families: List[str]
    parameter_size: str
    quantization_level: str


class OllamaModel(BaseModel):
    name: str
    model: str
    size: int
    digest: str
    modified_at: str
    details: OllamaModelDetails


class OllamaTagResponse(BaseModel):
    models: List[OllamaModel]


class QueryRequest(BaseModel):
    query: str
@@ -920,10 +834,10 @@ def create_app(args):
            if args.llm_binding == "lollms" or args.llm_binding == "ollama"
            else {},
            embedding_func=embedding_func,
            kv_storage=ollama_server_infos.KV_STORAGE,
            graph_storage=ollama_server_infos.GRAPH_STORAGE,
            vector_storage=ollama_server_infos.VECTOR_STORAGE,
            doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
            kv_storage=rag_storage_config.KV_STORAGE,
            graph_storage=rag_storage_config.GRAPH_STORAGE,
            vector_storage=rag_storage_config.VECTOR_STORAGE,
            doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
@@ -949,10 +863,10 @@ def create_app(args):
            llm_model_max_async=args.max_async,
            llm_model_max_token_size=args.max_tokens,
            embedding_func=embedding_func,
            kv_storage=ollama_server_infos.KV_STORAGE,
            graph_storage=ollama_server_infos.GRAPH_STORAGE,
            vector_storage=ollama_server_infos.VECTOR_STORAGE,
            doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
            kv_storage=rag_storage_config.KV_STORAGE,
            graph_storage=rag_storage_config.GRAPH_STORAGE,
            vector_storage=rag_storage_config.VECTOR_STORAGE,
            doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
@@ -1475,450 +1389,9 @@ def create_app(args):
    async def get_graphs(label: str):
        return await rag.get_graps(nodel_label=label, max_depth=100)

    # Ollama compatible API endpoints
    # -------------------------------------------------
    @app.get("/api/version")
    async def get_version():
        """Get Ollama version information"""
        return OllamaVersionResponse(version="0.5.4")

    @app.get("/api/tags")
    async def get_tags():
        """Return available models acting as an Ollama server"""
        return OllamaTagResponse(
            models=[
                {
                    "name": ollama_server_infos.LIGHTRAG_MODEL,
                    "model": ollama_server_infos.LIGHTRAG_MODEL,
                    "size": ollama_server_infos.LIGHTRAG_SIZE,
                    "digest": ollama_server_infos.LIGHTRAG_DIGEST,
                    "modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                    "details": {
                        "parent_model": "",
                        "format": "gguf",
                        "family": ollama_server_infos.LIGHTRAG_NAME,
                        "families": [ollama_server_infos.LIGHTRAG_NAME],
                        "parameter_size": "13B",
                        "quantization_level": "Q4_0",
                    },
                }
            ]
        )

    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
        """Parse query prefix to determine search mode
        Returns tuple of (cleaned_query, search_mode)
        """
        mode_map = {
            "/local ": SearchMode.local,
            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
            "/naive ": SearchMode.naive,
            "/hybrid ": SearchMode.hybrid,
            "/mix ": SearchMode.mix,
            "/bypass ": SearchMode.bypass,
        }

        for prefix, mode in mode_map.items():
            if query.startswith(prefix):
                # After removing prefix and leading spaces
                cleaned_query = query[len(prefix) :].lstrip()
                return cleaned_query, mode

        return query, SearchMode.hybrid

    @app.post("/api/generate")
    async def generate(raw_request: Request, request: OllamaGenerateRequest):
        """Handle generate completion requests acting as an Ollama model
        For compatibility purposes, the request is not processed by LightRAG,
        and will be handled by the underlying LLM model.
        """
        try:
            query = request.prompt
            start_time = time.time_ns()
            prompt_tokens = estimate_tokens(query)

            if request.system:
                rag.llm_model_kwargs["system_prompt"] = request.system

            if request.stream:
                from fastapi.responses import StreamingResponse

                response = await rag.llm_model_func(
                    query, stream=True, **rag.llm_model_kwargs
                )

                async def stream_generator():
                    try:
                        first_chunk_time = None
                        last_chunk_time = None
                        total_response = ""

                        # Ensure response is an async generator
                        if isinstance(response, str):
                            # If it's a string, send in two parts
                            first_chunk_time = time.time_ns()
                            last_chunk_time = first_chunk_time
                            total_response = response

                            data = {
                                "model": ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "response": response,
                                "done": False,
                            }
                            yield f"{json.dumps(data, ensure_ascii=False)}\n"

                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                "model": ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "done": True,
                                "total_duration": total_time,
                                "load_duration": 0,
                                "prompt_eval_count": prompt_tokens,
                                "prompt_eval_duration": prompt_eval_time,
                                "eval_count": completion_tokens,
                                "eval_duration": eval_time,
                            }
                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
                        else:
                            async for chunk in response:
                                if chunk:
                                    if first_chunk_time is None:
                                        first_chunk_time = time.time_ns()

                                    last_chunk_time = time.time_ns()

                                    total_response += chunk
                                    data = {
                                        "model": ollama_server_infos.LIGHTRAG_MODEL,
                                        "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                        "response": chunk,
                                        "done": False,
                                    }
                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"

                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                "model": ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "done": True,
                                "total_duration": total_time,
                                "load_duration": 0,
                                "prompt_eval_count": prompt_tokens,
                                "prompt_eval_duration": prompt_eval_time,
                                "eval_count": completion_tokens,
                                "eval_duration": eval_time,
                            }
                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            return

                    except Exception as e:
                        logging.error(f"Error in stream_generator: {str(e)}")
                        raise

                return StreamingResponse(
                    stream_generator(),
                    media_type="application/x-ndjson",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "Content-Type": "application/x-ndjson",
                        "Access-Control-Allow-Origin": "*",
                        "Access-Control-Allow-Methods": "POST, OPTIONS",
                        "Access-Control-Allow-Headers": "Content-Type",
                    },
                )
            else:
                first_chunk_time = time.time_ns()
                response_text = await rag.llm_model_func(
                    query, stream=False, **rag.llm_model_kwargs
                )
                last_chunk_time = time.time_ns()

                if not response_text:
                    response_text = "No response generated"

                completion_tokens = estimate_tokens(str(response_text))
                total_time = last_chunk_time - start_time
                prompt_eval_time = first_chunk_time - start_time
                eval_time = last_chunk_time - first_chunk_time

                return {
                    "model": ollama_server_infos.LIGHTRAG_MODEL,
                    "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                    "response": str(response_text),
                    "done": True,
                    "total_duration": total_time,
                    "load_duration": 0,
                    "prompt_eval_count": prompt_tokens,
                    "prompt_eval_duration": prompt_eval_time,
                    "eval_count": completion_tokens,
                    "eval_duration": eval_time,
                }
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e))

    @app.post("/api/chat")
    async def chat(raw_request: Request, request: OllamaChatRequest):
        """Process chat completion requests acting as an Ollama model
        Routes user queries through LightRAG by selecting query mode based on prefix indicators.
        Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the LLM.
        """
        try:
            # Get all messages
            messages = request.messages
            if not messages:
                raise HTTPException(status_code=400, detail="No messages provided")

            # Get the last message as query and previous messages as history
            query = messages[-1].content
            # Convert OllamaMessage objects to dictionaries
            conversation_history = [
                {"role": msg.role, "content": msg.content} for msg in messages[:-1]
            ]

            # Check for query prefix
            cleaned_query, mode = parse_query_mode(query)

            start_time = time.time_ns()
            prompt_tokens = estimate_tokens(cleaned_query)

            param_dict = {
                "mode": mode,
                "stream": request.stream,
                "only_need_context": False,
                "conversation_history": conversation_history,
                "top_k": args.top_k,
            }

            if args.history_turns is not None:
                param_dict["history_turns"] = args.history_turns

            query_param = QueryParam(**param_dict)

            if request.stream:
                from fastapi.responses import StreamingResponse

                # Determine if the request is prefixed with "/bypass"
                if mode == SearchMode.bypass:
                    if request.system:
                        rag.llm_model_kwargs["system_prompt"] = request.system
                    response = await rag.llm_model_func(
                        cleaned_query,
                        stream=True,
                        history_messages=conversation_history,
                        **rag.llm_model_kwargs,
                    )
                else:
                    response = await rag.aquery(  # Need await to get async generator
                        cleaned_query, param=query_param
                    )

                async def stream_generator():
                    first_chunk_time = None
                    last_chunk_time = None
                    total_response = ""

                    try:
                        # Ensure response is an async generator
                        if isinstance(response, str):
                            # If it's a string, send in two parts
                            first_chunk_time = time.time_ns()
                            last_chunk_time = first_chunk_time
                            total_response = response

                            data = {
                                "model": ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "message": {
                                    "role": "assistant",
                                    "content": response,
                                    "images": None,
                                },
                                "done": False,
                            }
                            yield f"{json.dumps(data, ensure_ascii=False)}\n"

                            completion_tokens = estimate_tokens(total_response)
                            total_time = last_chunk_time - start_time
                            prompt_eval_time = first_chunk_time - start_time
                            eval_time = last_chunk_time - first_chunk_time

                            data = {
                                "model": ollama_server_infos.LIGHTRAG_MODEL,
                                "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                "done": True,
                                "total_duration": total_time,
                                "load_duration": 0,
                                "prompt_eval_count": prompt_tokens,
                                "prompt_eval_duration": prompt_eval_time,
                                "eval_count": completion_tokens,
                                "eval_duration": eval_time,
                            }
                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
                        else:
                            try:
                                async for chunk in response:
                                    if chunk:
                                        if first_chunk_time is None:
                                            first_chunk_time = time.time_ns()

                                        last_chunk_time = time.time_ns()

                                        total_response += chunk
                                        data = {
                                            "model": ollama_server_infos.LIGHTRAG_MODEL,
                                            "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                            "message": {
                                                "role": "assistant",
                                                "content": chunk,
                                                "images": None,
                                            },
                                            "done": False,
                                        }
                                        yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            except (asyncio.CancelledError, Exception) as e:
                                error_msg = str(e)
                                if isinstance(e, asyncio.CancelledError):
                                    error_msg = "Stream was cancelled by server"
                                else:
                                    error_msg = f"Provider error: {error_msg}"

                                logging.error(f"Stream error: {error_msg}")

                                # Send error message to client
                                error_data = {
                                    "model": ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "message": {
                                        "role": "assistant",
                                        "content": f"\n\nError: {error_msg}",
                                        "images": None,
                                    },
                                    "done": False,
                                }
                                yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

                                # Send final message to close the stream
                                final_data = {
                                    "model": ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                }
                                yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                                return

                            if last_chunk_time is not None:
                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                    except Exception as e:
                        error_msg = f"Error in stream_generator: {str(e)}"
                        logging.error(error_msg)

                        # Send error message to client
                        error_data = {
                            "model": ollama_server_infos.LIGHTRAG_MODEL,
                            "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                            "error": {"code": "STREAM_ERROR", "message": error_msg},
                        }
                        yield f"{json.dumps(error_data, ensure_ascii=False)}\n"

                        # Ensure sending end marker
                        final_data = {
                            "model": ollama_server_infos.LIGHTRAG_MODEL,
                            "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                            "done": True,
                        }
                        yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                        return

                return StreamingResponse(
                    stream_generator(),
                    media_type="application/x-ndjson",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive",
                        "Content-Type": "application/x-ndjson",
                        "Access-Control-Allow-Origin": "*",
                        "Access-Control-Allow-Methods": "POST, OPTIONS",
                        "Access-Control-Allow-Headers": "Content-Type",
                    },
                )
            else:
                first_chunk_time = time.time_ns()

                # Determine if the request is prefixed with "/bypass" or comes from Open WebUI's session title and session keyword generation task
                match_result = re.search(
                    r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
                )
                if match_result or mode == SearchMode.bypass:
                    if request.system:
                        rag.llm_model_kwargs["system_prompt"] = request.system

                    response_text = await rag.llm_model_func(
                        cleaned_query,
                        stream=False,
                        history_messages=conversation_history,
                        **rag.llm_model_kwargs,
                    )
                else:
                    response_text = await rag.aquery(cleaned_query, param=query_param)

                last_chunk_time = time.time_ns()

                if not response_text:
                    response_text = "No response generated"

                completion_tokens = estimate_tokens(str(response_text))
                total_time = last_chunk_time - start_time
                prompt_eval_time = first_chunk_time - start_time
                eval_time = last_chunk_time - first_chunk_time

                return {
                    "model": ollama_server_infos.LIGHTRAG_MODEL,
                    "created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
                    "message": {
                        "role": "assistant",
                        "content": str(response_text),
                        "images": None,
                    },
                    "done": True,
                    "total_duration": total_time,
                    "load_duration": 0,
                    "prompt_eval_count": prompt_tokens,
                    "prompt_eval_duration": prompt_eval_time,
                    "eval_count": completion_tokens,
                    "eval_duration": eval_time,
                }
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e))
    # Add Ollama API routes
    ollama_api = OllamaAPI(rag)
    app.include_router(ollama_api.router, prefix="/api")

    @app.get("/documents", dependencies=[Depends(optional_api_key)])
    async def documents():
@@ -1945,10 +1418,10 @@ def create_app(args):
                "embedding_binding_host": args.embedding_binding_host,
                "embedding_model": args.embedding_model,
                "max_tokens": args.max_tokens,
                "kv_storage": ollama_server_infos.KV_STORAGE,
                "doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE,
                "graph_storage": ollama_server_infos.GRAPH_STORAGE,
                "vector_storage": ollama_server_infos.VECTOR_STORAGE,
                "kv_storage": rag_storage_config.KV_STORAGE,
                "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
                "graph_storage": rag_storage_config.GRAPH_STORAGE,
                "vector_storage": rag_storage_config.VECTOR_STORAGE,
            },
        }
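The prefix-routing helper removed above is re-created unchanged in the new module, so clients still pick a LightRAG query mode by prefixing their message. A small usage sketch (the import path follows the new file below; the query strings are made-up examples):

from lightrag.api.ollama_api import SearchMode, parse_query_mode

# "/local " selects SearchMode.local and the prefix is stripped from the query
assert parse_query_mode("/local What is LightRAG?") == ("What is LightRAG?", SearchMode.local)

# Queries without a recognized prefix fall back to hybrid mode
assert parse_query_mode("What is LightRAG?") == ("What is LightRAG?", SearchMode.hybrid)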

554  lightrag/api/ollama_api.py  Normal file
@@ -0,0 +1,554 @@
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import logging
import time
import json
import re
import os
from enum import Enum
from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam

class OllamaServerInfos:
    # Constants for emulated Ollama model information
    LIGHTRAG_NAME = "lightrag"
    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
    LIGHTRAG_SIZE = 7365960935  # it's a dummy value
    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
    LIGHTRAG_DIGEST = "sha256:lightrag"

ollama_server_infos = OllamaServerInfos()

# query mode according to query prefix (bypass is not a LightRAG query mode)
class SearchMode(str, Enum):
    naive = "naive"
    local = "local"
    global_ = "global"
    hybrid = "hybrid"
    mix = "mix"
    bypass = "bypass"

class OllamaMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None

class OllamaChatRequest(BaseModel):
    model: str
    messages: List[OllamaMessage]
    stream: bool = True
    options: Optional[Dict[str, Any]] = None
    system: Optional[str] = None

class OllamaChatResponse(BaseModel):
    model: str
    created_at: str
    message: OllamaMessage
    done: bool

class OllamaGenerateRequest(BaseModel):
    model: str
    prompt: str
    system: Optional[str] = None
    stream: bool = False
    options: Optional[Dict[str, Any]] = None

class OllamaGenerateResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool
    context: Optional[List[int]]
    total_duration: Optional[int]
    load_duration: Optional[int]
    prompt_eval_count: Optional[int]
    prompt_eval_duration: Optional[int]
    eval_count: Optional[int]
    eval_duration: Optional[int]

class OllamaVersionResponse(BaseModel):
    version: str

class OllamaModelDetails(BaseModel):
    parent_model: str
    format: str
    family: str
    families: List[str]
    parameter_size: str
    quantization_level: str

class OllamaModel(BaseModel):
    name: str
    model: str
    size: int
    digest: str
    modified_at: str
    details: OllamaModelDetails

class OllamaTagResponse(BaseModel):
    models: List[OllamaModel]

def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in text
    Chinese characters: approximately 1.5 tokens per character
    English characters: approximately 0.25 tokens per character
    """
    # Use regex to match Chinese and non-Chinese characters separately
    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))

    # Calculate estimated token count
    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25

    return int(tokens)

def parse_query_mode(query: str) -> tuple[str, SearchMode]:
    """Parse query prefix to determine search mode
    Returns tuple of (cleaned_query, search_mode)
    """
    mode_map = {
        "/local ": SearchMode.local,
        "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
        "/naive ": SearchMode.naive,
        "/hybrid ": SearchMode.hybrid,
        "/mix ": SearchMode.mix,
        "/bypass ": SearchMode.bypass,
    }

    for prefix, mode in mode_map.items():
        if query.startswith(prefix):
            # After removing prefix and leading spaces
            cleaned_query = query[len(prefix) :].lstrip()
            return cleaned_query, mode

    return query, SearchMode.hybrid

class OllamaAPI:
    def __init__(self, rag: LightRAG):
        self.rag = rag
        self.ollama_server_infos = ollama_server_infos
        self.router = APIRouter()
        self.setup_routes()

    def setup_routes(self):
        @self.router.get("/version")
        async def get_version():
            """Get Ollama version information"""
            return OllamaVersionResponse(version="0.5.4")

        @self.router.get("/tags")
        async def get_tags():
            """Return available models acting as an Ollama server"""
            return OllamaTagResponse(
                models=[
                    {
                        "name": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "size": self.ollama_server_infos.LIGHTRAG_SIZE,
                        "digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
                        "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "details": {
                            "parent_model": "",
                            "format": "gguf",
                            "family": self.ollama_server_infos.LIGHTRAG_NAME,
                            "families": [self.ollama_server_infos.LIGHTRAG_NAME],
                            "parameter_size": "13B",
                            "quantization_level": "Q4_0",
                        },
                    }
                ]
            )

        @self.router.post("/generate")
        async def generate(raw_request: Request, request: OllamaGenerateRequest):
            """Handle generate completion requests acting as an Ollama model
            For compatibility purposes, the request is not processed by LightRAG,
            and will be handled by the underlying LLM model.
            """
            try:
                query = request.prompt
                start_time = time.time_ns()
                prompt_tokens = estimate_tokens(query)

                if request.system:
                    self.rag.llm_model_kwargs["system_prompt"] = request.system

                if request.stream:
                    response = await self.rag.llm_model_func(
                        query, stream=True, **self.rag.llm_model_kwargs
                    )

                    async def stream_generator():
                        try:
                            first_chunk_time = None
                            last_chunk_time = None
                            total_response = ""

                            # Ensure response is an async generator
                            if isinstance(response, str):
                                # If it's a string, send in two parts
                                first_chunk_time = time.time_ns()
                                last_chunk_time = first_chunk_time
                                total_response = response

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "response": response,
                                    "done": False,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                            else:
                                async for chunk in response:
                                    if chunk:
                                        if first_chunk_time is None:
                                            first_chunk_time = time.time_ns()

                                        last_chunk_time = time.time_ns()

                                        total_response += chunk
                                        data = {
                                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                            "response": chunk,
                                            "done": False,
                                        }
                                        yield f"{json.dumps(data, ensure_ascii=False)}\n"

                                completion_tokens = estimate_tokens(total_response)
                                total_time = last_chunk_time - start_time
                                prompt_eval_time = first_chunk_time - start_time
                                eval_time = last_chunk_time - first_chunk_time

                                data = {
                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                                    "done": True,
                                    "total_duration": total_time,
                                    "load_duration": 0,
                                    "prompt_eval_count": prompt_tokens,
                                    "prompt_eval_duration": prompt_eval_time,
                                    "eval_count": completion_tokens,
                                    "eval_duration": eval_time,
                                }
                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
                                return

                        except Exception as e:
                            trace_exception(e)
                            raise

                    return StreamingResponse(
                        stream_generator(),
                        media_type="application/x-ndjson",
                        headers={
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                            "Content-Type": "application/x-ndjson",
                            "Access-Control-Allow-Origin": "*",
                            "Access-Control-Allow-Methods": "POST, OPTIONS",
                            "Access-Control-Allow-Headers": "Content-Type",
                        },
                    )
                else:
                    first_chunk_time = time.time_ns()
                    response_text = await self.rag.llm_model_func(
                        query, stream=False, **self.rag.llm_model_kwargs
                    )
                    last_chunk_time = time.time_ns()

                    if not response_text:
                        response_text = "No response generated"

                    completion_tokens = estimate_tokens(str(response_text))
                    total_time = last_chunk_time - start_time
                    prompt_eval_time = first_chunk_time - start_time
                    eval_time = last_chunk_time - first_chunk_time

                    return {
                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
                        "response": str(response_text),
                        "done": True,
                        "total_duration": total_time,
                        "load_duration": 0,
                        "prompt_eval_count": prompt_tokens,
                        "prompt_eval_duration": prompt_eval_time,
                        "eval_count": completion_tokens,
                        "eval_duration": eval_time,
                    }
            except Exception as e:
                trace_exception(e)
                raise HTTPException(status_code=500, detail=str(e))

@self.router.post("/chat")
|
||||
async def chat(raw_request: Request, request: OllamaChatRequest):
|
||||
"""Process chat completion requests acting as an Ollama model
|
||||
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
|
||||
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
|
||||
"""
|
||||
try:
|
||||
# Get all messages
|
||||
messages = request.messages
|
||||
if not messages:
|
||||
raise HTTPException(status_code=400, detail="No messages provided")
|
||||
|
||||
# Get the last message as query and previous messages as history
|
||||
query = messages[-1].content
|
||||
# Convert OllamaMessage objects to dictionaries
|
||||
conversation_history = [
|
||||
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
|
||||
]
|
||||
|
||||
# Check for query prefix
|
||||
cleaned_query, mode = parse_query_mode(query)
|
||||
|
||||
start_time = time.time_ns()
|
||||
prompt_tokens = estimate_tokens(cleaned_query)
|
||||
|
||||
param_dict = {
|
||||
"mode": mode,
|
||||
"stream": request.stream,
|
||||
"only_need_context": False,
|
||||
"conversation_history": conversation_history,
|
||||
"top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50,
|
||||
}
|
||||
|
||||
if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None:
|
||||
param_dict["history_turns"] = self.rag.args.history_turns
|
||||
|
||||
query_param = QueryParam(**param_dict)
|
||||
|
||||
if request.stream:
|
||||
# Determine if the request is prefix with "/bypass"
|
||||
if mode == SearchMode.bypass:
|
||||
if request.system:
|
||||
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
response = await self.rag.llm_model_func(
|
||||
cleaned_query,
|
||||
stream=True,
|
||||
history_messages=conversation_history,
|
||||
**self.rag.llm_model_kwargs,
|
||||
)
|
||||
else:
|
||||
response = await self.rag.aquery(
|
||||
cleaned_query, param=query_param
|
||||
)
|
||||
|
||||
async def stream_generator():
|
||||
first_chunk_time = None
|
||||
last_chunk_time = None
|
||||
total_response = ""
|
||||
|
||||
try:
|
||||
# Ensure response is an async generator
|
||||
if isinstance(response, str):
|
||||
# If it's a string, send in two parts
|
||||
first_chunk_time = time.time_ns()
|
||||
last_chunk_time = first_chunk_time
|
||||
total_response = response
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
"images": None,
|
||||
},
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
|
||||
completion_tokens = estimate_tokens(total_response)
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"done": True,
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
else:
|
||||
try:
|
||||
async for chunk in response:
|
||||
if chunk:
|
||||
if first_chunk_time is None:
|
||||
first_chunk_time = time.time_ns()
|
||||
|
||||
last_chunk_time = time.time_ns()
|
||||
|
||||
total_response += chunk
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": chunk,
|
||||
"images": None,
|
||||
},
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
except (asyncio.CancelledError, Exception) as e:
|
||||
error_msg = str(e)
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
error_msg = "Stream was cancelled by server"
|
||||
else:
|
||||
error_msg = f"Provider error: {error_msg}"
|
||||
|
||||
logging.error(f"Stream error: {error_msg}")
|
||||
|
||||
# Send error message to client
|
||||
error_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": f"\n\nError: {error_msg}",
|
||||
"images": None,
|
||||
},
|
||||
"done": False,
|
||||
}
|
||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||
|
||||
# Send final message to close the stream
|
||||
final_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"done": True,
|
||||
}
|
||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||
return
|
||||
|
||||
if last_chunk_time is not None:
|
||||
completion_tokens = estimate_tokens(total_response)
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"done": True,
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
yield f"{json.dumps(data, ensure_ascii=False)}\n"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in stream_generator: {str(e)}"
|
||||
logging.error(error_msg)
|
||||
|
||||
# Send error message to client
|
||||
error_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"error": {"code": "STREAM_ERROR", "message": error_msg},
|
||||
}
|
||||
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
|
||||
|
||||
# Ensure sending end marker
|
||||
final_data = {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"done": True,
|
||||
}
|
||||
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
|
||||
return
|
||||
|
||||
return StreamingResponse(
|
||||
stream_generator(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Content-Type": "application/x-ndjson",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type",
|
||||
},
|
||||
)
|
||||
else:
|
||||
first_chunk_time = time.time_ns()
|
||||
|
||||
# Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task
|
||||
match_result = re.search(
|
||||
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
|
||||
)
|
||||
if match_result or mode == SearchMode.bypass:
|
||||
if request.system:
|
||||
self.rag.llm_model_kwargs["system_prompt"] = request.system
|
||||
|
||||
response_text = await self.rag.llm_model_func(
|
||||
cleaned_query,
|
||||
stream=False,
|
||||
history_messages=conversation_history,
|
||||
**self.rag.llm_model_kwargs,
|
||||
)
|
||||
else:
|
||||
response_text = await self.rag.aquery(cleaned_query, param=query_param)
|
||||
|
||||
last_chunk_time = time.time_ns()
|
||||
|
||||
if not response_text:
|
||||
response_text = "No response generated"
|
||||
|
||||
completion_tokens = estimate_tokens(str(response_text))
|
||||
total_time = last_chunk_time - start_time
|
||||
prompt_eval_time = first_chunk_time - start_time
|
||||
eval_time = last_chunk_time - first_chunk_time
|
||||
|
||||
return {
|
||||
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
|
||||
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": str(response_text),
|
||||
"images": None,
|
||||
},
|
||||
"done": True,
|
||||
"total_duration": total_time,
|
||||
"load_duration": 0,
|
||||
"prompt_eval_count": prompt_tokens,
|
||||
"prompt_eval_duration": prompt_eval_time,
|
||||
"eval_count": completion_tokens,
|
||||
"eval_duration": eval_time,
|
||||
}
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
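With the router mounted under /api, the server keeps emulating an Ollama endpoint for clients such as Open WebUI. A client-side sketch using the requests library; the host and port are assumptions, not part of this diff (adjust them to your deployment):

import requests

BASE = "http://localhost:9621"  # assumed address of the LightRAG API server

# List the emulated model exposed by /api/tags (lightrag:latest by default)
print(requests.get(f"{BASE}/api/tags").json())

# Non-streaming chat routed through LightRAG in hybrid mode via the "/hybrid " prefix
resp = requests.post(
    f"{BASE}/api/chat",
    json={
        "model": "lightrag:latest",
        "messages": [{"role": "user", "content": "/hybrid What does the indexed corpus say about X?"}],
        "stream": False,
    },
)
print(resp.json()["message"]["content"])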