Split the Ollama API implementation into a separate file

yangdx
2025-02-05 22:15:14 +08:00
parent 2bbf451fa0
commit f703334ce4
2 changed files with 594 additions and 567 deletions
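The second changed file, the new module (presumably lightrag/api/ollama_api.py), is not shown in this view. Judging from the imports and router registration in the diff below, its shape is likely close to the sketch here; apart from OllamaAPI, ollama_server_infos, and the constants lifted from the removed OllamaServerInfos class, every name and detail is an assumption.

# Hypothetical outline of the new ollama_api.py (assumed, not part of this diff)
import os
from fastapi import APIRouter


class OllamaServerInfos:
    # Constants for the emulated Ollama model, moved out of the server module
    LIGHTRAG_NAME = "lightrag"
    LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
    LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
    LIGHTRAG_SIZE = 7365960935  # dummy value
    LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
    LIGHTRAG_DIGEST = "sha256:lightrag"


ollama_server_infos = OllamaServerInfos()


class OllamaAPI:
    def __init__(self, rag):
        self.rag = rag
        self.router = APIRouter()
        self.setup_routes()

    def setup_routes(self):
        @self.router.get("/version")
        async def get_version():
            # Same dummy version string the old inline endpoint returned
            return {"version": "0.5.4"}

        # /tags, /generate and /chat are registered here essentially unchanged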


@@ -4,28 +4,20 @@ from fastapi import (
File,
UploadFile,
Form,
Request,
BackgroundTasks,
)
# Backend (Python)
# Add this to store progress globally
from typing import Dict
import threading
import asyncio
import json
import os
import json
import re
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import logging
import argparse
import time
import re
from typing import List, Any, Optional, Union
from typing import List, Any, Optional, Union, Dict
from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from enum import Enum
from pathlib import Path
@@ -34,20 +26,30 @@ import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import sys
import configparser
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
from dotenv import load_dotenv
from .ollama_api import (
OllamaAPI,
)
from .ollama_api import ollama_server_infos
# Load environment variables
load_dotenv()
class RAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
# Initialize rag storage config
rag_storage_config = RAGStorageConfig()
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
@@ -76,24 +78,6 @@ def estimate_tokens(text: str) -> int:
return int(tokens)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
# Add infos
ollama_server_infos = OllamaServerInfos()
# read config.ini
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -101,8 +85,8 @@ config.read("config.ini", "utf-8")
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
os.environ["REDIS_URI"] = redis_uri
ollama_server_infos.KV_STORAGE = "RedisKVStorage"
ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage"
rag_storage_config.KV_STORAGE = "RedisKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
# Neo4j config
neo4j_uri = config.get("neo4j", "uri", fallback=None)
@@ -112,7 +96,7 @@ if neo4j_uri:
os.environ["NEO4J_URI"] = neo4j_uri
os.environ["NEO4J_USERNAME"] = neo4j_username
os.environ["NEO4J_PASSWORD"] = neo4j_password
ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage"
rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
# Milvus config
milvus_uri = config.get("milvus", "uri", fallback=None)
@@ -124,7 +108,7 @@ if milvus_uri:
os.environ["MILVUS_USER"] = milvus_user
os.environ["MILVUS_PASSWORD"] = milvus_password
os.environ["MILVUS_DB_NAME"] = milvus_db_name
ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge"
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
# MongoDB config
mongo_uri = config.get("mongodb", "uri", fallback=None)
@@ -132,8 +116,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
ollama_server_infos.KV_STORAGE = "MongoKVStorage"
ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage"
rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
def get_default_host(binding_type: str) -> str:
@@ -535,6 +519,7 @@ def parse_args() -> argparse.Namespace:
help="Cosine similarity threshold (default: from env or 0.4)",
)
# Ollama model name
parser.add_argument(
"--simulated-model-name",
type=str,
@@ -599,84 +584,13 @@ class DocumentManager:
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
# LightRAG query mode
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
mix = "mix"
bypass = "bypass"
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str = ollama_server_infos.LIGHTRAG_MODEL
messages: List[OllamaMessage]
stream: bool = True # Default to streaming mode
options: Optional[Dict[str, Any]] = None
system: Optional[str] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaGenerateRequest(BaseModel):
model: str = ollama_server_infos.LIGHTRAG_MODEL
prompt: str
system: Optional[str] = None
stream: bool = False
options: Optional[Dict[str, Any]] = None
class OllamaGenerateResponse(BaseModel):
model: str
created_at: str
response: str
done: bool
context: Optional[List[int]]
total_duration: Optional[int]
load_duration: Optional[int]
prompt_eval_count: Optional[int]
prompt_eval_duration: Optional[int]
eval_count: Optional[int]
eval_duration: Optional[int]
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: List[str]
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
size: int
digest: str
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: List[OllamaModel]
class QueryRequest(BaseModel):
query: str
@@ -920,10 +834,10 @@ def create_app(args):
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {},
embedding_func=embedding_func,
kv_storage=ollama_server_infos.KV_STORAGE,
graph_storage=ollama_server_infos.GRAPH_STORAGE,
vector_storage=ollama_server_infos.VECTOR_STORAGE,
doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
kv_storage=rag_storage_config.KV_STORAGE,
graph_storage=rag_storage_config.GRAPH_STORAGE,
vector_storage=rag_storage_config.VECTOR_STORAGE,
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
@@ -949,10 +863,10 @@ def create_app(args):
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
embedding_func=embedding_func,
kv_storage=ollama_server_infos.KV_STORAGE,
graph_storage=ollama_server_infos.GRAPH_STORAGE,
vector_storage=ollama_server_infos.VECTOR_STORAGE,
doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
kv_storage=rag_storage_config.KV_STORAGE,
graph_storage=rag_storage_config.GRAPH_STORAGE,
vector_storage=rag_storage_config.VECTOR_STORAGE,
doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
@@ -1475,450 +1389,9 @@ def create_app(args):
async def get_graphs(label: str):
return await rag.get_graps(nodel_label=label, max_depth=100)
# Ollama compatible API endpoints
# -------------------------------------------------
@app.get("/api/version")
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
@app.get("/api/tags")
async def get_tags():
"""Retrun available models acting as an Ollama server"""
return OllamaTagResponse(
models=[
{
"name": ollama_server_infos.LIGHTRAG_MODEL,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"size": ollama_server_infos.LIGHTRAG_SIZE,
"digest": ollama_server_infos.LIGHTRAG_DIGEST,
"modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"details": {
"parent_model": "",
"format": "gguf",
"family": ollama_server_infos.LIGHTRAG_NAME,
"families": [ollama_server_infos.LIGHTRAG_NAME],
"parameter_size": "13B",
"quantization_level": "Q4_0",
},
}
]
)
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
"""Parse query prefix to determine search mode
Returns tuple of (cleaned_query, search_mode)
"""
mode_map = {
"/local ": SearchMode.local,
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix,
"/bypass ": SearchMode.bypass,
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
# After removing the prefix and leading spaces
cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode
return query, SearchMode.hybrid
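As a quick illustration of the prefix handling above (example queries are made up):
# parse_query_mode("/local Who founded the company?")  -> ("Who founded the company?", SearchMode.local)
# parse_query_mode("Plain question, no prefix")        -> ("Plain question, no prefix", SearchMode.hybrid)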
@app.post("/api/generate")
async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model
For compatibility purposes, the request is not processed by LightRAG,
and will be handled by the underlying LLM model.
"""
try:
query = request.prompt
start_time = time.time_ns()
prompt_tokens = estimate_tokens(query)
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.llm_model_func(
query, stream=True, **rag.llm_model_kwargs
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
response_text = await rag.llm_model_func(
query, stream=False, **rag.llm_model_kwargs
)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text),
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Process chat completion requests acting as an Ollama model
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
Detects and forwards Open WebUI session-related requests (for metadata generation tasks) directly to the LLM.
"""
try:
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query and previous messages as history
query = messages[-1].content
# Convert OllamaMessage objects to dictionaries
conversation_history = [
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
]
# Check for query prefix
cleaned_query, mode = parse_query_mode(query)
start_time = time.time_ns()
prompt_tokens = estimate_tokens(cleaned_query)
param_dict = {
"mode": mode,
"stream": request.stream,
"only_need_context": False,
"conversation_history": conversation_history,
"top_k": args.top_k,
}
if args.history_turns is not None:
param_dict["history_turns"] = args.history_turns
query_param = QueryParam(**param_dict)
if request.stream:
from fastapi.responses import StreamingResponse
# Determine if the request is prefixed with "/bypass"
if mode == SearchMode.bypass:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response = await rag.llm_model_func(
cleaned_query,
stream=True,
history_messages=conversation_history,
**rag.llm_model_kwargs,
)
else:
response = await rag.aquery( # Need await to get async generator
cleaned_query, param=query_param
)
async def stream_generator():
first_chunk_time = None
last_chunk_time = None
total_response = ""
try:
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logging.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": f"\n\nError: {error_msg}",
"images": None,
},
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Send final message to close the stream
final_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if last_chunk_time is not None:
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except Exception as e:
error_msg = f"Error in stream_generator: {str(e)}"
logging.error(error_msg)
# Send error message to client
error_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"error": {"code": "STREAM_ERROR", "message": error_msg},
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Ensure the end marker is sent
final_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
# Determine if the request is prefixed with "/bypass" or comes from Open WebUI's session title or session keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
)
if match_result or mode == SearchMode.bypass:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await rag.llm_model_func(
cleaned_query,
stream=False,
history_messages=conversation_history,
**rag.llm_model_kwargs,
)
else:
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": str(response_text),
"images": None,
},
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
# Add Ollama API routes
ollama_api = OllamaAPI(rag)
app.include_router(ollama_api.router, prefix="/api")
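With the router mounted under /api, any Ollama client can talk to LightRAG as if it were an Ollama server. A minimal client-side sketch, assuming the default host/port and the emulated lightrag:latest model (all values illustrative):

# Hypothetical smoke test against the Ollama-compatible endpoints (assumed host/port)
import json
import requests

base = "http://localhost:9621/api"

print(requests.get(f"{base}/version").json())         # e.g. {"version": "0.5.4"}
print(requests.get(f"{base}/tags").json()["models"])   # the emulated lightrag:latest entry

payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/local What entities are in the document?"}],
    "stream": True,
}
with requests.post(f"{base}/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # responses are NDJSON: one JSON object per line
        if not chunk.get("done"):
            print(chunk["message"]["content"], end="", flush=True)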
@app.get("/documents", dependencies=[Depends(optional_api_key)])
async def documents():
@@ -1945,10 +1418,10 @@ def create_app(args):
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
"kv_storage": ollama_server_infos.KV_STORAGE,
"doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE,
"graph_storage": ollama_server_infos.GRAPH_STORAGE,
"vector_storage": ollama_server_infos.VECTOR_STORAGE,
"kv_storage": rag_storage_config.KV_STORAGE,
"doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
"graph_storage": rag_storage_config.GRAPH_STORAGE,
"vector_storage": rag_storage_config.VECTOR_STORAGE,
},
}