Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo
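To try the demo after this change, start the new example (it serves the API with uvicorn on port 8020 by default) and point Open-WebUI's Ollama connection at http://<host>:8020/ollama, as the index route below suggests. A quick sanity check, assuming the default host and port, is:

curl http://localhost:8020/ollama/api/version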
examples/lightrag_api_open_webui_demo.py (new file, +140 lines)
@@ -0,0 +1,140 @@
+from datetime import datetime, timezone
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+import inspect
+import json
+from pydantic import BaseModel
+from typing import Optional
+
+import os
+import logging
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.utils import EmbeddingFunc
+
+import nest_asyncio
+
+WORKING_DIR = "./dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="qwen2.5:latest",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: ollama_embed(
+            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+app = FastAPI(title="LightRAG", description="LightRAG API open-webui")
+
+
+# Data models
+MODEL_NAME = "LightRAG:latest"
+
+
+class Message(BaseModel):
+    role: Optional[str] = None
+    content: str
+
+
+class OpenWebUIRequest(BaseModel):
+    stream: Optional[bool] = None
+    model: Optional[str] = None
+    messages: list[Message]
+
+
+# API routes
+
+
+@app.get("/")
+async def index():
+    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"
+
+
+@app.get("/ollama/api/version")
+async def ollama_version():
+    return {"version": "0.4.7"}
+
+
+@app.get("/ollama/api/tags")
+async def ollama_tags():
+    return {
+        "models": [
+            {
+                "name": MODEL_NAME,
+                "model": MODEL_NAME,
+                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
+                "size": 4683087332,
+                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
+                "details": {
+                    "parent_model": "",
+                    "format": "gguf",
+                    "family": "qwen2",
+                    "families": ["qwen2"],
+                    "parameter_size": "7.6B",
+                    "quantization_level": "Q4_K_M",
+                },
+            }
+        ]
+    }
+
+
+@app.post("/ollama/api/chat")
+async def ollama_chat(request: OpenWebUIRequest):
+    resp = rag.query(
+        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
+    )
+    if inspect.isasyncgen(resp):
+
+        async def ollama_resp(chunks):
+            async for chunk in chunks:
+                yield (
+                    json.dumps(
+                        {
+                            "model": MODEL_NAME,
+                            "created_at": datetime.now(timezone.utc).strftime(
+                                "%Y-%m-%dT%H:%M:%S.%fZ"
+                            ),
+                            "message": {
+                                "role": "assistant",
+                                "content": chunk,
+                            },
+                            "done": False,
+                        },
+                        ensure_ascii=False,
+                    ).encode("utf-8")
+                    + b"\n"
+                ) # the b"\n" is important
+
+        return StreamingResponse(ollama_resp(resp), media_type="application/json")
+    else:
+        return resp
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8020)
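For reference, a minimal client-side sketch (illustrative, not part of the diff) of consuming the /ollama/api/chat stream directly; it assumes the demo above is running on localhost:8020 and that httpx is available. Each line of the response body is one JSON object in the shape built by ollama_resp (model, created_at, message.content, done):

import asyncio
import json

import httpx


async def main():
    payload = {
        "model": "LightRAG:latest",
        "stream": True,
        "messages": [{"role": "user", "content": "What are the top themes in this story?"}],
    }
    # No read timeout: the first chunk may take a while to arrive.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8020/ollama/api/chat", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if not line:
                    continue
                chunk = json.loads(line)  # one NDJSON object per line
                print(chunk["message"]["content"], end="", flush=True)


asyncio.run(main())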
@@ -1,4 +1,6 @@
+import asyncio
 import os
+import inspect
 import logging
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -49,3 +51,20 @@ print(
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 )
+
+# stream response
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        print(chunk, end="", flush=True)
+
+
+if inspect.isasyncgen(resp):
+    asyncio.run(print_stream(resp))
+else:
+    print(resp)
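A small variant of the print_stream consumer above (illustrative, not part of the diff) for callers that want the streamed answer joined into a single string rather than printed chunk by chunk:

async def collect_stream(stream) -> str:
    parts = []
    async for chunk in stream:  # each chunk is a piece of the answer text
        parts.append(chunk)
    return "".join(parts)


# e.g. full_text = asyncio.run(collect_stream(resp)) when resp is an async generator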
@@ -19,6 +19,7 @@ class QueryParam:
     only_need_context: bool = False
     only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
+    stream: bool = False
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
     # Number of document chunks to retrieve.
@@ -27,7 +27,7 @@ from tenacity import (
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Union
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -37,6 +37,13 @@ from .utils import (
     get_best_cached_response,
 )
 
+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -454,7 +461,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -494,28 +502,39 @@ async def ollama_model_if_cache(
                 return if_cache_return["return"]
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
-
-    result = response["message"]["content"]
-
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return result
+    if stream:
+        """ cannot cache stream response """
+
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
+
+        return inner()
+    else:
+        result = response["message"]["content"]
+        if hashing_kv is not None:
+            await hashing_kv.upsert(
+                {
+                    args_hash: {
+                        "return": result,
+                        "model": model,
+                        "embedding": quantized.tobytes().hex()
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_shape": quantized.shape
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_min": min_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_max": max_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "original_prompt": prompt,
+                    }
+                }
+            )
+        return result
 
 
 @lru_cache(maxsize=1)
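The inner() generator above wraps the streaming behaviour of the Ollama Python client used by this function: when chat() is awaited with stream=True, the result is an async iterator of chunk dicts whose ["message"]["content"] holds the incremental text. A standalone sketch of that pattern (illustrative only; host and model names are placeholders):

import asyncio

import ollama


async def demo():
    client = ollama.AsyncClient(host="http://localhost:11434")
    stream = await client.chat(
        model="qwen2.5:latest",
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        stream=True,
    )
    async for chunk in stream:  # each chunk carries a partial assistant message
        print(chunk["message"]["content"], end="", flush=True)


asyncio.run(demo())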
@@ -785,7 +804,7 @@ async def hf_model_complete(
 
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"
@@ -534,8 +534,9 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        stream=query_param.stream,
     )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
             .replace("user", "")