Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo
examples/lightrag_api_open_webui_demo.py (new file, +140 lines)
@@ -0,0 +1,140 @@
+from datetime import datetime, timezone
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+import inspect
+import json
+from pydantic import BaseModel
+from typing import Optional
+
+import os
+import logging
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.utils import EmbeddingFunc
+
+import nest_asyncio
+
+WORKING_DIR = "./dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="qwen2.5:latest",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: ollama_embed(
+            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+app = FastAPI(title="LightRAG", description="LightRAG API open-webui")
+
+
+# Data models
+MODEL_NAME = "LightRAG:latest"
+
+
+class Message(BaseModel):
+    role: Optional[str] = None
+    content: str
+
+
+class OpenWebUIRequest(BaseModel):
+    stream: Optional[bool] = None
+    model: Optional[str] = None
+    messages: list[Message]
+
+
+# API routes
+
+
+@app.get("/")
+async def index():
+    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"
+
+
+@app.get("/ollama/api/version")
+async def ollama_version():
+    return {"version": "0.4.7"}
+
+
+@app.get("/ollama/api/tags")
+async def ollama_tags():
+    return {
+        "models": [
+            {
+                "name": MODEL_NAME,
+                "model": MODEL_NAME,
+                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
+                "size": 4683087332,
+                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
+                "details": {
+                    "parent_model": "",
+                    "format": "gguf",
+                    "family": "qwen2",
+                    "families": ["qwen2"],
+                    "parameter_size": "7.6B",
+                    "quantization_level": "Q4_K_M",
+                },
+            }
+        ]
+    }
+
+
+@app.post("/ollama/api/chat")
+async def ollama_chat(request: OpenWebUIRequest):
+    resp = rag.query(
+        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
+    )
+    if inspect.isasyncgen(resp):
+
+        async def ollama_resp(chunks):
+            async for chunk in chunks:
+                yield (
+                    json.dumps(
+                        {
+                            "model": MODEL_NAME,
+                            "created_at": datetime.now(timezone.utc).strftime(
+                                "%Y-%m-%dT%H:%M:%S.%fZ"
+                            ),
+                            "message": {
+                                "role": "assistant",
+                                "content": chunk,
+                            },
+                            "done": False,
+                        },
+                        ensure_ascii=False,
+                    ).encode("utf-8")
+                    + b"\n"
+                ) # the b"\n" is important
+
+        return StreamingResponse(ollama_resp(resp), media_type="application/json")
+    else:
+        return resp
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
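For reference, a minimal client-side sketch (illustrative, not part of the diff) that consumes the newline-delimited JSON stream emitted by the /ollama/api/chat route above. It assumes the demo is running on localhost:8020 and that the requests library is available; the payload shape follows the OpenWebUIRequest model.

import json

import requests  # assumed to be installed; any HTTP client works

payload = {
    "model": "LightRAG:latest",
    "stream": True,
    "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
}
with requests.post(
    "http://localhost:8020/ollama/api/chat", json=payload, stream=True
) as r:
    for line in r.iter_lines():  # one JSON object per line, as emitted by ollama_resp()
        if line:
            chunk = json.loads(line)
            print(chunk["message"]["content"], end="", flush=True)

Each line is a self-contained JSON object in Ollama's chat-response format, which is what lets Open-WebUI render tokens as they arrive.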
@@ -1,4 +1,6 @@
+import asyncio
 import os
+import inspect
 import logging
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 )
+
+# stream response
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        print(chunk, end="", flush=True)
+
+
+if inspect.isasyncgen(resp):
+    asyncio.run(print_stream(resp))
+else:
+    print(resp)
@@ -19,6 +19,7 @@ class QueryParam:
     only_need_context: bool = False
     only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
+    stream: bool = False
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
     # Number of document chunks to retrieve.
@@ -27,7 +27,7 @@ from tenacity import (
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Union
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -37,6 +37,13 @@ from .utils import (
     get_best_cached_response,
 )
 
+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -454,7 +461,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -494,28 +502,39 @@ async def ollama_model_if_cache(
         return if_cache_return["return"]
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
-
-    result = response["message"]["content"]
-
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return result
+    if stream:
+        """ cannot cache stream response """
+
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
+
+        return inner()
+    else:
+        result = response["message"]["content"]
+        if hashing_kv is not None:
+            await hashing_kv.upsert(
+                {
+                    args_hash: {
+                        "return": result,
+                        "model": model,
+                        "embedding": quantized.tobytes().hex()
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_shape": quantized.shape
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_min": min_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_max": max_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "original_prompt": prompt,
+                    }
+                }
+            )
+        return result
 
 
 @lru_cache(maxsize=1)
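With this change, ollama_model_if_cache (and ollama_model_complete below) return either a plain string or an async iterator of text chunks. A minimal sketch of how a caller might normalize that Union[str, AsyncIterator[str]] back into a full string when incremental output is not needed; collect_response is a hypothetical helper name, not part of the library:

import inspect


async def collect_response(resp) -> str:
    # resp may be a str (non-streaming) or an async generator of chunks (streaming).
    if inspect.isasyncgen(resp):
        return "".join([chunk async for chunk in resp])
    return resp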
@@ -785,7 +804,7 @@ async def hf_model_complete(
 
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"
@@ -534,8 +534,9 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        stream=query_param.stream,
     )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
             .replace("user", "")
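The isinstance guard above is needed because a streaming response is an async generator: it has no len() and must not go through the prompt-echo cleanup, so it is returned to the caller unchanged and consumed chunk by chunk. A minimal sketch of the same pattern; postprocess is a hypothetical helper, not LightRAG API:

def postprocess(response, sys_prompt: str):
    # Only plain-string responses get the prompt-echo cleanup; async generators
    # (stream=True) are passed through untouched for downstream streaming.
    if isinstance(response, str) and len(response) > len(sys_prompt):
        return response.replace(sys_prompt, "").strip()
    return response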