Split the Ollama API implementation into a separate file

yangdx
2025-02-05 22:15:14 +08:00
parent 2bbf451fa0
commit f703334ce4
2 changed files with 594 additions and 567 deletions
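In effect, the Ollama-compatible endpoints move out of the server module into a new lightrag/api/ollama_api.py, and the server now only wires the extracted router in. A minimal sketch of that wiring, using the names visible in the diff below (the surrounding create_app() and rag objects are assumed):

# Inside create_app(), after the LightRAG instance `rag` has been built:
from .ollama_api import OllamaAPI

ollama_api = OllamaAPI(rag)                            # router carrying /generate, /chat, /tags, /version
app.include_router(ollama_api.router, prefix="/api")   # exposed as /api/generate, /api/chat, ...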


@@ -4,28 +4,20 @@ from fastapi import (
File,
UploadFile,
Form,
-Request,
BackgroundTasks,
)
-# Backend (Python)
-# Add this to store progress globally
-from typing import Dict
import threading
-import asyncio
-import json
import os
+import json
+import re
from fastapi.staticfiles import StaticFiles
-from pydantic import BaseModel
import logging
import argparse
-import time
-import re
-from typing import List, Any, Optional, Union
+from typing import List, Any, Optional, Union, Dict
+from pydantic import BaseModel
from lightrag import LightRAG, QueryParam
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from enum import Enum
from pathlib import Path
@@ -34,20 +26,30 @@ import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import sys
import configparser
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
from dotenv import load_dotenv
from .ollama_api import (
OllamaAPI,
)
from .ollama_api import ollama_server_infos
# Load environment variables
load_dotenv()
class RAGStorageConfig:
KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
# Initialize rag storage config
rag_storage_config = RAGStorageConfig()
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
@@ -76,24 +78,6 @@ def estimate_tokens(text: str) -> int:
return int(tokens)
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
KV_STORAGE = "JsonKVStorage"
DOC_STATUS_STORAGE = "JsonDocStatusStorage"
GRAPH_STORAGE = "NetworkXStorage"
VECTOR_STORAGE = "NanoVectorDBStorage"
# Add infos
ollama_server_infos = OllamaServerInfos()
# read config.ini
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -101,8 +85,8 @@ config.read("config.ini", "utf-8")
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
os.environ["REDIS_URI"] = redis_uri
-ollama_server_infos.KV_STORAGE = "RedisKVStorage"
-ollama_server_infos.DOC_STATUS_STORAGE = "RedisKVStorage"
+rag_storage_config.KV_STORAGE = "RedisKVStorage"
+rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
# Neo4j config
neo4j_uri = config.get("neo4j", "uri", fallback=None)
@@ -112,7 +96,7 @@ if neo4j_uri:
os.environ["NEO4J_URI"] = neo4j_uri os.environ["NEO4J_URI"] = neo4j_uri
os.environ["NEO4J_USERNAME"] = neo4j_username os.environ["NEO4J_USERNAME"] = neo4j_username
os.environ["NEO4J_PASSWORD"] = neo4j_password os.environ["NEO4J_PASSWORD"] = neo4j_password
ollama_server_infos.GRAPH_STORAGE = "Neo4JStorage" rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
# Milvus config
milvus_uri = config.get("milvus", "uri", fallback=None)
@@ -124,7 +108,7 @@ if milvus_uri:
os.environ["MILVUS_USER"] = milvus_user os.environ["MILVUS_USER"] = milvus_user
os.environ["MILVUS_PASSWORD"] = milvus_password os.environ["MILVUS_PASSWORD"] = milvus_password
os.environ["MILVUS_DB_NAME"] = milvus_db_name os.environ["MILVUS_DB_NAME"] = milvus_db_name
ollama_server_infos.VECTOR_STORAGE = "MilvusVectorDBStorge" rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
# MongoDB config
mongo_uri = config.get("mongodb", "uri", fallback=None)
@@ -132,8 +116,8 @@ mongo_database = config.get("mongodb", "LightRAG", fallback=None)
if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
-ollama_server_infos.KV_STORAGE = "MongoKVStorage"
-ollama_server_infos.DOC_STATUS_STORAGE = "MongoKVStorage"
+rag_storage_config.KV_STORAGE = "MongoKVStorage"
+rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
def get_default_host(binding_type: str) -> str:
@@ -535,6 +519,7 @@ def parse_args() -> argparse.Namespace:
help="Cosine similarity threshold (default: from env or 0.4)", help="Cosine similarity threshold (default: from env or 0.4)",
) )
# Ollama model name
parser.add_argument(
"--simulated-model-name",
type=str,
@@ -599,84 +584,13 @@ class DocumentManager:
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
-# Pydantic models
+# LightRAG query mode
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
mix = "mix"
bypass = "bypass"
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str = ollama_server_infos.LIGHTRAG_MODEL
messages: List[OllamaMessage]
stream: bool = True # Default to streaming mode
options: Optional[Dict[str, Any]] = None
system: Optional[str] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaGenerateRequest(BaseModel):
model: str = ollama_server_infos.LIGHTRAG_MODEL
prompt: str
system: Optional[str] = None
stream: bool = False
options: Optional[Dict[str, Any]] = None
class OllamaGenerateResponse(BaseModel):
model: str
created_at: str
response: str
done: bool
context: Optional[List[int]]
total_duration: Optional[int]
load_duration: Optional[int]
prompt_eval_count: Optional[int]
prompt_eval_duration: Optional[int]
eval_count: Optional[int]
eval_duration: Optional[int]
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: List[str]
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
size: int
digest: str
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: List[OllamaModel]
class QueryRequest(BaseModel):
query: str
@@ -920,10 +834,10 @@ def create_app(args):
if args.llm_binding == "lollms" or args.llm_binding == "ollama"
else {},
embedding_func=embedding_func,
-kv_storage=ollama_server_infos.KV_STORAGE,
-graph_storage=ollama_server_infos.GRAPH_STORAGE,
-vector_storage=ollama_server_infos.VECTOR_STORAGE,
-doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
+kv_storage=rag_storage_config.KV_STORAGE,
+graph_storage=rag_storage_config.GRAPH_STORAGE,
+vector_storage=rag_storage_config.VECTOR_STORAGE,
+doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
@@ -949,10 +863,10 @@ def create_app(args):
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
embedding_func=embedding_func,
-kv_storage=ollama_server_infos.KV_STORAGE,
-graph_storage=ollama_server_infos.GRAPH_STORAGE,
-vector_storage=ollama_server_infos.VECTOR_STORAGE,
-doc_status_storage=ollama_server_infos.DOC_STATUS_STORAGE,
+kv_storage=rag_storage_config.KV_STORAGE,
+graph_storage=rag_storage_config.GRAPH_STORAGE,
+vector_storage=rag_storage_config.VECTOR_STORAGE,
+doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
@@ -1475,450 +1389,9 @@ def create_app(args):
async def get_graphs(label: str):
return await rag.get_graps(nodel_label=label, max_depth=100)
-# Ollama compatible API endpoints
-# -------------------------------------------------
-@app.get("/api/version")
+# Add Ollama API routes
+ollama_api = OllamaAPI(rag)
+app.include_router(ollama_api.router, prefix="/api")
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
@app.get("/api/tags")
async def get_tags():
"""Retrun available models acting as an Ollama server"""
return OllamaTagResponse(
models=[
{
"name": ollama_server_infos.LIGHTRAG_MODEL,
"model": ollama_server_infos.LIGHTRAG_MODEL,
"size": ollama_server_infos.LIGHTRAG_SIZE,
"digest": ollama_server_infos.LIGHTRAG_DIGEST,
"modified_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"details": {
"parent_model": "",
"format": "gguf",
"family": ollama_server_infos.LIGHTRAG_NAME,
"families": [ollama_server_infos.LIGHTRAG_NAME],
"parameter_size": "13B",
"quantization_level": "Q4_0",
},
}
]
)
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
"""Parse query prefix to determine search mode
Returns tuple of (cleaned_query, search_mode)
"""
mode_map = {
"/local ": SearchMode.local,
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix,
"/bypass ": SearchMode.bypass,
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
# After removing prefix and leading spaces
cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode
return query, SearchMode.hybrid
@app.post("/api/generate")
async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model
For compatibility purposes, the request is not processed by LightRAG,
and will be handled by the underlying LLM model.
"""
try:
query = request.prompt
start_time = time.time_ns()
prompt_tokens = estimate_tokens(query)
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.llm_model_func(
query, stream=True, **rag.llm_model_kwargs
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
response_text = await rag.llm_model_func(
query, stream=False, **rag.llm_model_kwargs
)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text),
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Process chat completion requests acting as an Ollama model
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the LLM.
"""
try:
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query and previous messages as history
query = messages[-1].content
# Convert OllamaMessage objects to dictionaries
conversation_history = [
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
]
# Check for query prefix
cleaned_query, mode = parse_query_mode(query)
start_time = time.time_ns()
prompt_tokens = estimate_tokens(cleaned_query)
param_dict = {
"mode": mode,
"stream": request.stream,
"only_need_context": False,
"conversation_history": conversation_history,
"top_k": args.top_k,
}
if args.history_turns is not None:
param_dict["history_turns"] = args.history_turns
query_param = QueryParam(**param_dict)
if request.stream:
from fastapi.responses import StreamingResponse
# Determine if the request is prefixed with "/bypass"
if mode == SearchMode.bypass:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response = await rag.llm_model_func(
cleaned_query,
stream=True,
history_messages=conversation_history,
**rag.llm_model_kwargs,
)
else:
response = await rag.aquery( # Need await to get async generator
cleaned_query, param=query_param
)
async def stream_generator():
first_chunk_time = None
last_chunk_time = None
total_response = ""
try:
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logging.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": f"\n\nError: {error_msg}",
"images": None,
},
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Send final message to close the stream
final_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if last_chunk_time is not None:
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except Exception as e:
error_msg = f"Error in stream_generator: {str(e)}"
logging.error(error_msg)
# Send error message to client
error_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"error": {"code": "STREAM_ERROR", "message": error_msg},
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Ensure sending end marker
final_data = {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
# Determine if the request is prefixed with "/bypass" or comes from Open WebUI's session title and session keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
)
if match_result or mode == SearchMode.bypass:
if request.system:
rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await rag.llm_model_func(
cleaned_query,
stream=False,
history_messages=conversation_history,
**rag.llm_model_kwargs,
)
else:
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": ollama_server_infos.LIGHTRAG_MODEL,
"created_at": ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": str(response_text),
"images": None,
},
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@app.get("/documents", dependencies=[Depends(optional_api_key)]) @app.get("/documents", dependencies=[Depends(optional_api_key)])
async def documents(): async def documents():
@@ -1945,10 +1418,10 @@ def create_app(args):
"embedding_binding_host": args.embedding_binding_host, "embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model, "embedding_model": args.embedding_model,
"max_tokens": args.max_tokens, "max_tokens": args.max_tokens,
"kv_storage": ollama_server_infos.KV_STORAGE, "kv_storage": rag_storage_config.KV_STORAGE,
"doc_status_storage": ollama_server_infos.DOC_STATUS_STORAGE, "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE,
"graph_storage": ollama_server_infos.GRAPH_STORAGE, "graph_storage": rag_storage_config.GRAPH_STORAGE,
"vector_storage": ollama_server_infos.VECTOR_STORAGE, "vector_storage": rag_storage_config.VECTOR_STORAGE,
}, },
} }

lightrag/api/ollama_api.py (new file, 554 lines)

@@ -0,0 +1,554 @@
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import logging
import time
import json
import re
import os
from enum import Enum
from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
class OllamaServerInfos:
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = os.getenv("OLLAMA_EMULATING_MODEL_TAG", "latest")
LIGHTRAG_MODEL = f"{LIGHTRAG_NAME}:{LIGHTRAG_TAG}"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
ollama_server_infos = OllamaServerInfos()
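# For example, starting the server with OLLAMA_EMULATING_MODEL_TAG=dev (an illustrative value)
# would advertise the emulated model as "lightrag:dev"; without it the default is "lightrag:latest".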
# Query mode selected by query prefix ("bypass" is not a LightRAG query mode)
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global"
hybrid = "hybrid"
mix = "mix"
bypass = "bypass"
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str
messages: List[OllamaMessage]
stream: bool = True
options: Optional[Dict[str, Any]] = None
system: Optional[str] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaGenerateRequest(BaseModel):
model: str
prompt: str
system: Optional[str] = None
stream: bool = False
options: Optional[Dict[str, Any]] = None
class OllamaGenerateResponse(BaseModel):
model: str
created_at: str
response: str
done: bool
context: Optional[List[int]]
total_duration: Optional[int]
load_duration: Optional[int]
prompt_eval_count: Optional[int]
prompt_eval_duration: Optional[int]
eval_count: Optional[int]
eval_duration: Optional[int]
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: List[str]
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
size: int
digest: str
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: List[OllamaModel]
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
Chinese characters: approximately 1.5 tokens per character
English characters: approximately 0.25 tokens per character
"""
# Use regex to match Chinese and non-Chinese characters separately
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
# Calculate estimated token count
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
return int(tokens)
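# Rough examples of the heuristic above (illustrative inputs):
#   estimate_tokens("hello")   -> int(5 * 0.25)             == 1
#   estimate_tokens("你好")     -> int(2 * 1.5)               == 3
#   estimate_tokens("hi 你好")  -> int(3 * 0.25 + 2 * 1.5)    == 3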
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
"""Parse query prefix to determine search mode
Returns tuple of (cleaned_query, search_mode)
"""
mode_map = {
"/local ": SearchMode.local,
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix,
"/bypass ": SearchMode.bypass,
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
# After removing prefix and leading spaces
cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode
return query, SearchMode.hybrid
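# For instance (hypothetical queries):
#   parse_query_mode("/local Who is the author?") -> ("Who is the author?", SearchMode.local)
#   parse_query_mode("No prefix at all")          -> ("No prefix at all", SearchMode.hybrid)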
class OllamaAPI:
def __init__(self, rag: LightRAG):
self.rag = rag
self.ollama_server_infos = ollama_server_infos
self.router = APIRouter()
self.setup_routes()
def setup_routes(self):
@self.router.get("/version")
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
@self.router.get("/tags")
async def get_tags():
"""Return available models acting as an Ollama server"""
return OllamaTagResponse(
models=[
{
"name": self.ollama_server_infos.LIGHTRAG_MODEL,
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"size": self.ollama_server_infos.LIGHTRAG_SIZE,
"digest": self.ollama_server_infos.LIGHTRAG_DIGEST,
"modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"details": {
"parent_model": "",
"format": "gguf",
"family": self.ollama_server_infos.LIGHTRAG_NAME,
"families": [self.ollama_server_infos.LIGHTRAG_NAME],
"parameter_size": "13B",
"quantization_level": "Q4_0",
},
}
]
)
@self.router.post("/generate")
async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests acting as an Ollama model
For compatibility purposes, the request is not processed by LightRAG,
and will be handled by the underlying LLM model.
"""
try:
query = request.prompt
start_time = time.time_ns()
prompt_tokens = estimate_tokens(query)
if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system
if request.stream:
response = await self.rag.llm_model_func(
query, stream=True, **self.rag.llm_model_kwargs
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": response,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": chunk,
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return
except Exception as e:
trace_exception(e)
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
response_text = await self.rag.llm_model_func(
query, stream=False, **self.rag.llm_model_kwargs
)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"response": str(response_text),
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@self.router.post("/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Process chat completion requests acting as an Ollama model
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the LLM.
"""
try:
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query and previous messages as history
query = messages[-1].content
# Convert OllamaMessage objects to dictionaries
conversation_history = [
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
]
# Check for query prefix
cleaned_query, mode = parse_query_mode(query)
start_time = time.time_ns()
prompt_tokens = estimate_tokens(cleaned_query)
param_dict = {
"mode": mode,
"stream": request.stream,
"only_need_context": False,
"conversation_history": conversation_history,
"top_k": self.rag.args.top_k if hasattr(self.rag, 'args') else 50,
}
if hasattr(self.rag, 'args') and self.rag.args.history_turns is not None:
param_dict["history_turns"] = self.rag.args.history_turns
query_param = QueryParam(**param_dict)
if request.stream:
# Determine if the request is prefixed with "/bypass"
if mode == SearchMode.bypass:
if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system
response = await self.rag.llm_model_func(
cleaned_query,
stream=True,
history_messages=conversation_history,
**self.rag.llm_model_kwargs,
)
else:
response = await self.rag.aquery(
cleaned_query, param=query_param
)
async def stream_generator():
first_chunk_time = None
last_chunk_time = None
total_response = ""
try:
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
try:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except (asyncio.CancelledError, Exception) as e:
error_msg = str(e)
if isinstance(e, asyncio.CancelledError):
error_msg = "Stream was cancelled by server"
else:
error_msg = f"Provider error: {error_msg}"
logging.error(f"Stream error: {error_msg}")
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": f"\n\nError: {error_msg}",
"images": None,
},
"done": False,
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Send final message to close the stream
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
if last_chunk_time is not None:
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
except Exception as e:
error_msg = f"Error in stream_generator: {str(e)}"
logging.error(error_msg)
# Send error message to client
error_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"error": {"code": "STREAM_ERROR", "message": error_msg},
}
yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
# Ensure sending end marker
final_data = {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"done": True,
}
yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
return
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
# Determine if the request is prefixed with "/bypass" or comes from Open WebUI's session title and session keyword generation task
match_result = re.search(
r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
)
if match_result or mode == SearchMode.bypass:
if request.system:
self.rag.llm_model_kwargs["system_prompt"] = request.system
response_text = await self.rag.llm_model_func(
cleaned_query,
stream=False,
history_messages=conversation_history,
**self.rag.llm_model_kwargs,
)
else:
response_text = await self.rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": self.ollama_server_infos.LIGHTRAG_MODEL,
"created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": str(response_text),
"images": None,
},
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
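For reference, a minimal way to exercise the extracted endpoints once the server is running; the host/port and query text below are illustrative assumptions, not part of this commit:

import requests  # third-party HTTP client, assumed to be installed

BASE = "http://localhost:9621/api"  # assumed LightRAG server address; adjust to your deployment

print(requests.get(f"{BASE}/version").json())  # -> {"version": "0.5.4"}
print(requests.get(f"{BASE}/tags").json())     # advertises the emulated "lightrag:latest" model

# Non-streaming chat routed through LightRAG in local mode via the "/local " prefix
reply = requests.post(
    f"{BASE}/chat",
    json={
        "model": "lightrag:latest",
        "messages": [{"role": "user", "content": "/local What is this project about?"}],
        "stream": False,
    },
).json()
print(reply["message"]["content"])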