Migrate Ollama API to lightrag_server.py

yangdx
2025-01-19 04:44:30 +08:00
parent 11f32555b1
commit 8ea179a98b
2 changed files with 438 additions and 63 deletions

View File

@@ -25,9 +25,9 @@ EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
EMBEDDING_MODEL=bge-m3:latest
# Lollms example
EMBEDDING_BINDING=lollms
EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
EMBEDDING_MODEL=bge-m3:latest
# EMBEDDING_BINDING=lollms
# EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
# EMBEDDING_MODEL=bge-m3:latest
# RAG Configuration
MAX_ASYNC=4

View File

@@ -1,7 +1,11 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request
from pydantic import BaseModel
import logging
import argparse
import json
import time
import re
from typing import List, Dict, Any, Optional, Union
from lightrag import LightRAG, QueryParam
from lightrag.llm import lollms_model_complete, lollms_embed
from lightrag.llm import ollama_model_complete, ollama_embed
@@ -10,7 +14,6 @@ from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from typing import Optional, List, Union, Any
from enum import Enum
from pathlib import Path
import shutil
@@ -28,16 +31,41 @@ import pipmaster as pm
from dotenv import load_dotenv
load_dotenv()
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
Chinese characters: approximately 1.5 tokens per character
Other (non-Chinese) characters: approximately 0.25 tokens per character
"""
# Use regex to match Chinese and non-Chinese characters separately
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
# Calculate estimated token count
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
return int(tokens)
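As a quick sanity check of the heuristic (illustrative only, not part of the diff; assumes the estimate_tokens defined above is in scope):

# "Hello, 世界" has 2 Chinese characters and 7 other characters
# 2 * 1.5 + 7 * 0.25 = 4.75, which int() truncates to 4
print(estimate_tokens("Hello, 世界"))  # -> 4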
# Constants for emulated Ollama model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
LIGHTRAG_MODEL = "lightrag:latest"
LIGHTRAG_SIZE = 7365960935 # it's a dummy value
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": "http://localhost:11434",
"lollms": "http://localhost:9600",
"azure_openai": "https://api.openai.com/v1",
"openai": "https://api.openai.com/v1",
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
"lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
"azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
"openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
}
return default_hosts.get(
binding_type, "http://localhost:11434"
binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
) # fallback to ollama if unknown
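The practical effect is that LLM_BINDING_HOST (or AZURE_OPENAI_ENDPOINT for azure_openai) now overrides the per-binding defaults. A minimal, self-contained sketch of the same lookup pattern, using a hypothetical helper name rather than the committed function:

import os

def default_llm_host(binding_type: str) -> str:
    # Environment variable wins; otherwise fall back to the binding's usual local port
    defaults = {
        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
        "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
    }
    return defaults.get(binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434"))

os.environ["LLM_BINDING_HOST"] = "http://gpu-box:11434"
print(default_llm_host("ollama"))  # -> http://gpu-box:11434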
@@ -214,9 +242,7 @@ def parse_args() -> argparse.Namespace:
Returns:
argparse.Namespace: Parsed arguments
"""
# Load environment variables from .env file
load_dotenv()
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
@@ -409,6 +435,53 @@ class SearchMode(str, Enum):
local = "local"
global_ = "global"
hybrid = "hybrid"
mix = "mix"
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str = LIGHTRAG_MODEL
messages: List[OllamaMessage]
stream: bool = True # Default to streaming mode
options: Optional[Dict[str, Any]] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: List[str]
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
size: int
digest: str
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: List[OllamaModel]
class QueryRequest(BaseModel):
@@ -514,50 +587,107 @@ def create_app(args):
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
async def openai_alike_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=os.getenv("OPENAI_API_KEY"),
**kwargs,
)
async def azure_openai_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await azure_openai_complete_if_cache(
args.llm_model,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
**kwargs,
)
# Initialize RAG
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_model_complete
if args.llm_binding == "ollama"
else azure_openai_complete_if_cache
if args.llm_binding == "azure_openai"
else openai_complete_if_cache,
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={
"host": args.llm_binding_host,
"timeout": args.timeout,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.llm_binding in ["lollms", "ollama"] :
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.llm_binding == "ollama"
else azure_openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
)
if args.llm_binding == "azure_openai"
else openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
else ollama_model_complete,
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={
"host": args.llm_binding_host,
"timeout": args.timeout,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.embedding_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.embedding_binding == "ollama"
else azure_openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
)
if args.embedding_binding == "azure_openai"
else openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
),
),
),
)
)
else :
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=azure_openai_model_complete
if args.llm_binding == "azure_openai"
else openai_alike_model_complete,
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.embedding_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.embedding_binding == "ollama"
else azure_openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
)
if args.embedding_binding == "azure_openai"
else openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
),
),
)
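Stripped of the diff interleaving, the intent of this hunk is a two-way split on args.llm_binding: lollms and ollama keep the host/timeout/num_ctx kwargs and the model name, while azure_openai and openai route through the wrapper completion functions defined above. A condensed, illustrative sketch of just the selection logic (pick_llm_func is a hypothetical helper, not the committed code):

def pick_llm_func(llm_binding: str):
    # lollms/ollama talk to a local model server; the others go through the
    # azure_openai_model_complete / openai_alike_model_complete wrappers above
    if llm_binding in ["lollms", "ollama"]:
        return lollms_model_complete if llm_binding == "lollms" else ollama_model_complete
    return (
        azure_openai_model_complete
        if llm_binding == "azure_openai"
        else openai_alike_model_complete
    )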
async def index_file(file_path: Union[str, Path]) -> None:
"""Index all files inside the folder with support for multiple file formats
@@ -592,7 +722,7 @@ def create_app(args):
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from pypdf2 import PdfReader
from PyPDF2 import PdfReader
# PDF handling
reader = PdfReader(str(file_path))
@@ -711,13 +841,21 @@ def create_app(args):
),
)
# If response is a string (e.g. cache hit), return directly
if isinstance(response, str):
return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
return QueryResponse(response=response)
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@@ -725,7 +863,7 @@ def create_app(args):
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = rag.query(
response = await rag.aquery( # Use aquery instead of query, and add await
request.query,
param=QueryParam(
mode=request.mode,
@@ -734,12 +872,37 @@ def create_app(args):
),
)
async def stream_generator():
async for chunk in response:
yield chunk
from fastapi.responses import StreamingResponse
return stream_generator()
async def stream_generator():
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
else:
# If it's an async generator, send chunks one by one
try:
async for chunk in response:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"X-Accel-Buffering": "no", # Disable Nginx buffering
},
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
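Since /query/stream now emits newline-delimited JSON (one {"response": ...} or {"error": ...} object per line), a client can consume it line by line. A minimal sketch with httpx; the port and request-body fields are assumptions based on QueryRequest and are not part of this commit:

import json
import httpx

with httpx.stream(
    "POST",
    "http://localhost:9621/query/stream",  # assumed default host/port
    json={"query": "What is LightRAG?", "mode": "hybrid"},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk.get("response", ""), end="", flush=True)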
@app.post(
@@ -790,7 +953,7 @@ def create_app(args):
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from pypdf2 import PdfReader
from PyPDF2 import PdfReader
from io import BytesIO
# Read PDF from memory
@@ -897,7 +1060,7 @@ def create_app(args):
case ".pdf":
if not pm.is_installed("pypdf2"):
pm.install("pypdf2")
from pypdf2 import PdfReader
from PyPDF2 import PdfReader
from io import BytesIO
pdf_content = await file.read()
@@ -993,6 +1156,218 @@ def create_app(args):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Ollama compatible API endpoints
@app.get("/api/version")
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
@app.get("/api/tags")
async def get_tags():
"""Get available models"""
return OllamaTagResponse(
models=[
{
"name": LIGHTRAG_MODEL,
"model": LIGHTRAG_MODEL,
"size": LIGHTRAG_SIZE,
"digest": LIGHTRAG_DIGEST,
"modified_at": LIGHTRAG_CREATED_AT,
"details": {
"parent_model": "",
"format": "gguf",
"family": LIGHTRAG_NAME,
"families": [LIGHTRAG_NAME],
"parameter_size": "13B",
"quantization_level": "Q4_0",
},
}
]
)
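Because /api/version and /api/tags mirror Ollama's response shapes, Ollama-aware tools should now list LightRAG as an ordinary model. A quick check with plain HTTP (server address assumed):

import httpx

base = "http://localhost:9621"  # assumed LightRAG server address
print(httpx.get(f"{base}/api/version").json()["version"])  # -> 0.5.4
print([m["name"] for m in httpx.get(f"{base}/api/tags").json()["models"]])  # -> ['lightrag:latest']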
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
"""Parse query prefix to determine search mode
Returns tuple of (cleaned_query, search_mode)
"""
mode_map = {
"/local ": SearchMode.local,
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix,
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
# Strip the prefix and any leading whitespace
cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode
return query, SearchMode.hybrid
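A few examples of how the prefix parsing behaves (these follow directly from the mode_map above; assumes parse_query_mode and SearchMode are in scope):

assert parse_query_mode("/local who mentions the lighthouse?") == ("who mentions the lighthouse?", SearchMode.local)
assert parse_query_mode("/mix   summarize chapter 1") == ("summarize chapter 1", SearchMode.mix)
assert parse_query_mode("plain question") == ("plain question", SearchMode.hybrid)  # no prefix -> hybrid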
@app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Handle chat completion requests"""
try:
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query
query = messages[-1].content
# Parse the query mode from the prefix
cleaned_query, mode = parse_query_mode(query)
# Start timing
start_time = time.time_ns()
# Estimate the number of input tokens
prompt_tokens = estimate_tokens(cleaned_query)
# Call RAG to run the query
query_param = QueryParam(
mode=mode, stream=request.stream, only_need_context=False
)
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.aquery( # Need await to get async generator
cleaned_query, param=query_param
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# The response is either a plain string (e.g. cache hit) or an async generator
if isinstance(response, str):
# If it's a string, send the full content in one chunk, followed by the final stats chunk
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return # Ensure the generator ends immediately after sending the completion marker
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": str(response_text),
"images": None,
},
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
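Taken together, this exposes LightRAG through the standard Ollama chat API, with the /mode prefixes selecting the query mode. A non-streaming request sketch (server address assumed; the model name is the LIGHTRAG_MODEL constant above):

import httpx

payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/hybrid What are the main themes in the corpus?"}],
    "stream": False,
}
reply = httpx.post("http://localhost:9621/api/chat", json=payload, timeout=None).json()
print(reply["message"]["content"])  # answer text
print(reply["eval_count"])          # estimated completion tokens (see estimate_tokens above)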
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""