Migrate Ollama API to lightrag_server.py
@@ -25,9 +25,9 @@ EMBEDDING_BINDING_HOST=http://host.docker.internal:11434
 EMBEDDING_MODEL=bge-m3:latest

 # Lollms example
-EMBEDDING_BINDING=lollms
-EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
-EMBEDDING_MODEL=bge-m3:latest
+# EMBEDDING_BINDING=lollms
+# EMBEDDING_BINDING_HOST=http://host.docker.internal:9600
+# EMBEDDING_MODEL=bge-m3:latest

 # RAG Configuration
 MAX_ASYNC=4
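Note: these keys are only read at startup through python-dotenv, so commenting a line out simply makes the server fall back to a built-in default. A minimal sketch of how the values above are consumed (the fallback defaults shown here are assumptions, not necessarily the server's exact ones):

import os
from dotenv import load_dotenv

load_dotenv()  # loads EMBEDDING_* and other keys from .env into the environment

# Commented-out keys are absent, so os.getenv returns the fallback instead.
embedding_binding = os.getenv("EMBEDDING_BINDING", "ollama")
embedding_host = os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434")
embedding_model = os.getenv("EMBEDDING_MODEL", "bge-m3:latest")
print(embedding_binding, embedding_host, embedding_model)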
@@ -1,7 +1,11 @@
-from fastapi import FastAPI, HTTPException, File, UploadFile, Form
+from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request
 from pydantic import BaseModel
 import logging
 import argparse
+import json
+import time
+import re
+from typing import List, Dict, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import lollms_model_complete, lollms_embed
 from lightrag.llm import ollama_model_complete, ollama_embed
@@ -10,7 +14,6 @@ from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 from lightrag.api import __api_version__

 from lightrag.utils import EmbeddingFunc
-from typing import Optional, List, Union, Any
 from enum import Enum
 from pathlib import Path
 import shutil
@@ -28,16 +31,41 @@ import pipmaster as pm

 from dotenv import load_dotenv

+load_dotenv()
+
+
+def estimate_tokens(text: str) -> int:
+    """Estimate the number of tokens in text
+    Chinese characters: approximately 1.5 tokens per character
+    English characters: approximately 0.25 tokens per character
+    """
+    # Use regex to match Chinese and non-Chinese characters separately
+    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
+    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
+
+    # Calculate estimated token count
+    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
+
+    return int(tokens)
+
+
+# Constants for emulated Ollama model information
+LIGHTRAG_NAME = "lightrag"
+LIGHTRAG_TAG = "latest"
+LIGHTRAG_MODEL = "lightrag:latest"
+LIGHTRAG_SIZE = 7365960935  # it's a dummy value
+LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
+LIGHTRAG_DIGEST = "sha256:lightrag"
+
+
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
-        "ollama": "http://localhost:11434",
-        "lollms": "http://localhost:9600",
-        "azure_openai": "https://api.openai.com/v1",
-        "openai": "https://api.openai.com/v1",
+        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
+        "lollms": os.getenv("LLM_BINDING_HOST", "http://localhost:9600"),
+        "azure_openai": os.getenv("AZURE_OPENAI_ENDPOINT", "https://api.openai.com/v1"),
+        "openai": os.getenv("LLM_BINDING_HOST", "https://api.openai.com/v1"),
     }
     return default_hosts.get(
-        binding_type, "http://localhost:11434"
+        binding_type, os.getenv("LLM_BINDING_HOST", "http://localhost:11434")
     )  # fallback to ollama if unknown
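Note: the token estimate is a pure character-count heuristic, so it can be checked in isolation. A standalone copy for a quick sanity check (the sample strings are illustrative):

import re


def estimate_tokens(text: str) -> int:
    # Same heuristic as the diff above: ~1.5 tokens per Chinese character, ~0.25 otherwise
    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
    return int(chinese_chars * 1.5 + non_chinese_chars * 0.25)


print(estimate_tokens("hello world"))  # 11 non-Chinese chars -> 2
print(estimate_tokens("你好世界"))  # 4 Chinese chars -> 6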
@@ -214,8 +242,6 @@ def parse_args() -> argparse.Namespace:
     Returns:
         argparse.Namespace: Parsed arguments
     """
-    # Load environment variables from .env file
-    load_dotenv()

     parser = argparse.ArgumentParser(
         description="LightRAG FastAPI Server with separate working and input directories"
@@ -409,6 +435,53 @@ class SearchMode(str, Enum):
     local = "local"
     global_ = "global"
     hybrid = "hybrid"
+    mix = "mix"
+
+
+class OllamaMessage(BaseModel):
+    role: str
+    content: str
+    images: Optional[List[str]] = None
+
+
+class OllamaChatRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    messages: List[OllamaMessage]
+    stream: bool = True  # Default to streaming mode
+    options: Optional[Dict[str, Any]] = None
+
+
+class OllamaChatResponse(BaseModel):
+    model: str
+    created_at: str
+    message: OllamaMessage
+    done: bool
+
+
+class OllamaVersionResponse(BaseModel):
+    version: str
+
+
+class OllamaModelDetails(BaseModel):
+    parent_model: str
+    format: str
+    family: str
+    families: List[str]
+    parameter_size: str
+    quantization_level: str
+
+
+class OllamaModel(BaseModel):
+    name: str
+    model: str
+    size: int
+    digest: str
+    modified_at: str
+    details: OllamaModelDetails
+
+
+class OllamaTagResponse(BaseModel):
+    models: List[OllamaModel]


 class QueryRequest(BaseModel):
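Note: these Pydantic models mirror Ollama's /api/chat request shape, so an unmodified Ollama client payload validates directly. A small sketch with standalone copies of the two request models (the sample payload is illustrative):

from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class OllamaMessage(BaseModel):
    role: str
    content: str
    images: Optional[List[str]] = None


class OllamaChatRequest(BaseModel):
    model: str = "lightrag:latest"
    messages: List[OllamaMessage]
    stream: bool = True
    options: Optional[Dict[str, Any]] = None


# The JSON an Ollama client (e.g. a chat UI) would POST to /api/chat
payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/hybrid what is in the indexed corpus?"}],
    "stream": False,
}
request = OllamaChatRequest(**payload)
print(request.messages[-1].content)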
@@ -514,50 +587,107 @@ def create_app(args):
     # Initialize document manager
     doc_manager = DocumentManager(args.input_dir)

+    async def openai_alike_model_complete(
+        prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    ) -> str:
+        return await openai_complete_if_cache(
+            args.llm_model,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            base_url=args.llm_binding_host,
+            api_key=os.getenv("OPENAI_API_KEY"),
+            **kwargs,
+        )
+
+    async def azure_openai_model_complete(
+        prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    ) -> str:
+        return await azure_openai_complete_if_cache(
+            args.llm_model,
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            base_url=args.llm_binding_host,
+            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+            api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
+            **kwargs,
+        )
+
     # Initialize RAG
-    rag = LightRAG(
-        working_dir=args.working_dir,
-        llm_model_func=lollms_model_complete
-        if args.llm_binding == "lollms"
-        else ollama_model_complete
-        if args.llm_binding == "ollama"
-        else azure_openai_complete_if_cache
-        if args.llm_binding == "azure_openai"
-        else openai_complete_if_cache,
-        llm_model_name=args.llm_model,
-        llm_model_max_async=args.max_async,
-        llm_model_max_token_size=args.max_tokens,
-        llm_model_kwargs={
-            "host": args.llm_binding_host,
-            "timeout": args.timeout,
-            "options": {"num_ctx": args.max_tokens},
-        },
-        embedding_func=EmbeddingFunc(
-            embedding_dim=args.embedding_dim,
-            max_token_size=args.max_embed_tokens,
-            func=lambda texts: lollms_embed(
-                texts,
-                embed_model=args.embedding_model,
-                host=args.embedding_binding_host,
-            )
-            if args.llm_binding == "lollms"
-            else ollama_embed(
-                texts,
-                embed_model=args.embedding_model,
-                host=args.embedding_binding_host,
-            )
-            if args.llm_binding == "ollama"
-            else azure_openai_embedding(
-                texts,
-                model=args.embedding_model,  # no host is used for openai
-            )
-            if args.llm_binding == "azure_openai"
-            else openai_embedding(
-                texts,
-                model=args.embedding_model,  # no host is used for openai
-            ),
-        ),
-    )
+    if args.llm_binding in ["lollms", "ollama"] :
+        rag = LightRAG(
+            working_dir=args.working_dir,
+            llm_model_func=lollms_model_complete
+            if args.llm_binding == "lollms"
+            else ollama_model_complete,
+            llm_model_name=args.llm_model,
+            llm_model_max_async=args.max_async,
+            llm_model_max_token_size=args.max_tokens,
+            llm_model_kwargs={
+                "host": args.llm_binding_host,
+                "timeout": args.timeout,
+                "options": {"num_ctx": args.max_tokens},
+            },
+            embedding_func=EmbeddingFunc(
+                embedding_dim=args.embedding_dim,
+                max_token_size=args.max_embed_tokens,
+                func=lambda texts: lollms_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "lollms"
+                else ollama_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "ollama"
+                else azure_openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                )
+                if args.embedding_binding == "azure_openai"
+                else openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                ),
+            ),
+        )
+    else :
+        rag = LightRAG(
+            working_dir=args.working_dir,
+            llm_model_func=azure_openai_model_complete
+            if args.llm_binding == "azure_openai"
+            else openai_alike_model_complete,
+            embedding_func=EmbeddingFunc(
+                embedding_dim=args.embedding_dim,
+                max_token_size=args.max_embed_tokens,
+                func=lambda texts: lollms_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "lollms"
+                else ollama_embed(
+                    texts,
+                    embed_model=args.embedding_model,
+                    host=args.embedding_binding_host,
+                )
+                if args.embedding_binding == "ollama"
+                else azure_openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                )
+                if args.embedding_binding == "azure_openai"
+                else openai_embedding(
+                    texts,
+                    model=args.embedding_model,  # no host is used for openai
+                ),
+            ),
+        )

     async def index_file(file_path: Union[str, Path]) -> None:
         """Index all files inside the folder with support for multiple file formats
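Note: both branches rely on Python's chained conditional expressions to pick a binding; the chain is read top to bottom and the final else is the fallback. A toy illustration of the same dispatch pattern (the names here are placeholders, not the real binding functions):

def pick_binding(binding: str) -> str:
    # Mirrors the llm_model_func / embedding_func selection style used above
    return (
        "lollms_embed"
        if binding == "lollms"
        else "ollama_embed"
        if binding == "ollama"
        else "azure_openai_embedding"
        if binding == "azure_openai"
        else "openai_embedding"
    )


assert pick_binding("ollama") == "ollama_embed"
assert pick_binding("something-else") == "openai_embedding"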
@@ -592,7 +722,7 @@ def create_app(args):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from pypdf2 import PdfReader
+                from PyPDF2 import PdfReader

                 # PDF handling
                 reader = PdfReader(str(file_path))
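Note: pip treats the distribution name case-insensitively, so pm.install("pypdf2") still works, but the module import is case-sensitive, hence the PyPDF2 fix. Minimal usage matching the corrected import (the file path is illustrative):

from PyPDF2 import PdfReader

reader = PdfReader("example.pdf")  # placeholder path
text = ""
for page in reader.pages:
    text += page.extract_text() + "\n"
print(f"extracted {len(text)} characters")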
@@ -711,13 +841,21 @@ def create_app(args):
                 ),
             )

+            # If response is a string (e.g. cache hit), return directly
+            if isinstance(response, str):
+                return QueryResponse(response=response)
+
+            # If it's an async generator, decide whether to stream based on stream parameter
             if request.stream:
                 result = ""
                 async for chunk in response:
                     result += chunk
                 return QueryResponse(response=result)
             else:
-                return QueryResponse(response=response)
+                result = ""
+                async for chunk in response:
+                    result += chunk
+                return QueryResponse(response=result)
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
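Note: with the change above, rag.aquery may hand back either a finished string (e.g. a cache hit) or an async generator, and the non-streaming path drains the generator into a single string. A self-contained sketch of that draining logic (the fake generator stands in for the real response):

import asyncio


async def fake_response():
    # Stand-in for the async generator rag.aquery() can return
    for part in ["Hello", ", ", "world"]:
        yield part


async def collect(response) -> str:
    if isinstance(response, str):  # e.g. cache hit
        return response
    result = ""
    async for chunk in response:
        result += chunk
    return result


print(asyncio.run(collect(fake_response())))  # Hello, world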
@@ -725,7 +863,7 @@ def create_app(args):
     @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
     async def query_text_stream(request: QueryRequest):
         try:
-            response = rag.query(
+            response = await rag.aquery(  # Use aquery instead of query, and add await
                 request.query,
                 param=QueryParam(
                     mode=request.mode,
@@ -734,12 +872,37 @@ def create_app(args):
                 ),
             )

-            async def stream_generator():
-                async for chunk in response:
-                    yield chunk
-
-            return stream_generator()
+            from fastapi.responses import StreamingResponse
+
+            async def stream_generator():
+                if isinstance(response, str):
+                    # If it's a string, send it all at once
+                    yield f"{json.dumps({'response': response})}\n"
+                else:
+                    # If it's an async generator, send chunks one by one
+                    try:
+                        async for chunk in response:
+                            if chunk:  # Only send non-empty content
+                                yield f"{json.dumps({'response': chunk})}\n"
+                    except Exception as e:
+                        logging.error(f"Streaming error: {str(e)}")
+                        yield f"{json.dumps({'error': str(e)})}\n"
+
+            return StreamingResponse(
+                stream_generator(),
+                media_type="application/x-ndjson",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Content-Type": "application/x-ndjson",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Methods": "POST, OPTIONS",
+                    "Access-Control-Allow-Headers": "Content-Type",
+                    "X-Accel-Buffering": "no",  # Disable Nginx buffering
+                },
+            )
         except Exception as e:
+            trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))

     @app.post(
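Note: the stream endpoint now emits newline-delimited JSON (one object per line) instead of raw text chunks. A client-side sketch for consuming it, assuming the requests library and a default server address (host and port here are assumptions):

import json

import requests  # any HTTP client that supports streaming works

url = "http://localhost:9621/query/stream"  # assumed address of lightrag_server.py
payload = {"query": "What is LightRAG?", "mode": "hybrid", "stream": True}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        data = json.loads(line)  # one JSON object per line (application/x-ndjson)
        print(data.get("response") or data.get("error", ""), end="", flush=True)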
@@ -790,7 +953,7 @@ def create_app(args):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from pypdf2 import PdfReader
+                from PyPDF2 import PdfReader
                 from io import BytesIO

                 # Read PDF from memory
@@ -897,7 +1060,7 @@ def create_app(args):
             case ".pdf":
                 if not pm.is_installed("pypdf2"):
                     pm.install("pypdf2")
-                from pypdf2 import PdfReader
+                from PyPDF2 import PdfReader
                 from io import BytesIO

                 pdf_content = await file.read()
@@ -993,6 +1156,218 @@ def create_app(args):
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))

+    # Ollama compatible API endpoints
+    @app.get("/api/version")
+    async def get_version():
+        """Get Ollama version information"""
+        return OllamaVersionResponse(version="0.5.4")
+
+    @app.get("/api/tags")
+    async def get_tags():
+        """Get available models"""
+        return OllamaTagResponse(
+            models=[
+                {
+                    "name": LIGHTRAG_MODEL,
+                    "model": LIGHTRAG_MODEL,
+                    "size": LIGHTRAG_SIZE,
+                    "digest": LIGHTRAG_DIGEST,
+                    "modified_at": LIGHTRAG_CREATED_AT,
+                    "details": {
+                        "parent_model": "",
+                        "format": "gguf",
+                        "family": LIGHTRAG_NAME,
+                        "families": [LIGHTRAG_NAME],
+                        "parameter_size": "13B",
+                        "quantization_level": "Q4_0",
+                    },
+                }
+            ]
+        )
+
+    def parse_query_mode(query: str) -> tuple[str, SearchMode]:
+        """Parse query prefix to determine search mode
+        Returns tuple of (cleaned_query, search_mode)
+        """
+        mode_map = {
+            "/local ": SearchMode.local,
+            "/global ": SearchMode.global_,  # global_ is used because 'global' is a Python keyword
+            "/naive ": SearchMode.naive,
+            "/hybrid ": SearchMode.hybrid,
+            "/mix ": SearchMode.mix,
+        }
+
+        for prefix, mode in mode_map.items():
+            if query.startswith(prefix):
+                # After removing prefix and leading spaces
+                cleaned_query = query[len(prefix) :].lstrip()
+                return cleaned_query, mode
+
+        return query, SearchMode.hybrid
+
+    @app.post("/api/chat")
+    async def chat(raw_request: Request, request: OllamaChatRequest):
+        """Handle chat completion requests"""
+        try:
+            # Get all messages
+            messages = request.messages
+            if not messages:
+                raise HTTPException(status_code=400, detail="No messages provided")
+
+            # Get the last message as query
+            query = messages[-1].content
+
+            # Parse query mode
+            cleaned_query, mode = parse_query_mode(query)
+
+            # Start timing
+            start_time = time.time_ns()
+
+            # Count input tokens
+            prompt_tokens = estimate_tokens(cleaned_query)
+
+            # Call RAG to run the query
+            query_param = QueryParam(
+                mode=mode, stream=request.stream, only_need_context=False
+            )
+
+            if request.stream:
+                from fastapi.responses import StreamingResponse
+
+                response = await rag.aquery(  # Need await to get async generator
+                    cleaned_query, param=query_param
+                )
+
+                async def stream_generator():
+                    try:
+                        first_chunk_time = None
+                        last_chunk_time = None
+                        total_response = ""
+
+                        # Ensure response is an async generator
+                        if isinstance(response, str):
+                            # If it's a string, send in two parts
+                            first_chunk_time = time.time_ns()
+                            last_chunk_time = first_chunk_time
+                            total_response = response
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": response,
+                                    "images": None,
+                                },
+                                "done": False,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        else:
+                            async for chunk in response:
+                                if chunk:
+                                    if first_chunk_time is None:
+                                        first_chunk_time = time.time_ns()
+
+                                    last_chunk_time = time.time_ns()
+
+                                    total_response += chunk
+                                    data = {
+                                        "model": LIGHTRAG_MODEL,
+                                        "created_at": LIGHTRAG_CREATED_AT,
+                                        "message": {
+                                            "role": "assistant",
+                                            "content": chunk,
+                                            "images": None,
+                                        },
+                                        "done": False,
+                                    }
+                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                            return  # Ensure the generator ends immediately after sending the completion marker
+                    except Exception as e:
+                        logging.error(f"Error in stream_generator: {str(e)}")
+                        raise
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="application/x-ndjson",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "application/x-ndjson",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type",
+                    },
+                )
+            else:
+                first_chunk_time = time.time_ns()
+                response_text = await rag.aquery(cleaned_query, param=query_param)
+                last_chunk_time = time.time_ns()
+
+                if not response_text:
+                    response_text = "No response generated"
+
+                completion_tokens = estimate_tokens(str(response_text))
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                return {
+                    "model": LIGHTRAG_MODEL,
+                    "created_at": LIGHTRAG_CREATED_AT,
+                    "message": {
+                        "role": "assistant",
+                        "content": str(response_text),
+                        "images": None,
+                    },
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time,
+                }
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
     @app.get("/health", dependencies=[Depends(optional_api_key)])
     async def get_status():
         """Get current system status"""
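Note: with these endpoints in place, any Ollama-compatible client can point at the server and chat with the emulated lightrag:latest model; a "/local ", "/global ", "/naive ", "/hybrid " or "/mix " prefix on the message selects the search mode. A minimal client sketch (host and port are assumptions):

import json

import requests

url = "http://localhost:9621/api/chat"  # assumed address of lightrag_server.py
payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/mix summarize the indexed documents"}],
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        data = json.loads(line)
        if data.get("done"):
            break  # the final line carries timing/token statistics instead of content
        print(data["message"]["content"], end="", flush=True)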