Merge pull request #650 from danielaskdd/Add-history-support-for-ollama-api

Add history support for ollama api
zrguo
2025-01-27 06:34:10 +08:00
committed by GitHub
6 changed files with 122 additions and 183 deletions

View File

@@ -43,6 +43,9 @@ MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
+#HISTORY_TURNS=3
+#CHUNK_SIZE=1200
+#CHUNK_OVERLAP_SIZE=100
# Security (empty for no key)
LIGHTRAG_API_KEY=your-secure-api-key-here

View File

@@ -1,140 +0,0 @@
from datetime import datetime, timezone
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import inspect
import json
from pydantic import BaseModel
from typing import Optional
import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc
import nest_asyncio

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2.5:latest",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

app = FastAPI(title="LightRAG", description="LightRAG API open-webui")

# Data models
MODEL_NAME = "LightRAG:latest"


class Message(BaseModel):
    role: Optional[str] = None
    content: str


class OpenWebUIRequest(BaseModel):
    stream: Optional[bool] = None
    model: Optional[str] = None
    messages: list[Message]


# API routes
@app.get("/")
async def index():
    return "Set Ollama link to http://ip:port/ollama in Open-WebUI Settings"


@app.get("/ollama/api/version")
async def ollama_version():
    return {"version": "0.4.7"}


@app.get("/ollama/api/tags")
async def ollama_tags():
    return {
        "models": [
            {
                "name": MODEL_NAME,
                "model": MODEL_NAME,
                "modified_at": "2024-11-12T20:22:37.561463923+08:00",
                "size": 4683087332,
                "digest": "845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e",
                "details": {
                    "parent_model": "",
                    "format": "gguf",
                    "family": "qwen2",
                    "families": ["qwen2"],
                    "parameter_size": "7.6B",
                    "quantization_level": "Q4_K_M",
                },
            }
        ]
    }


@app.post("/ollama/api/chat")
async def ollama_chat(request: OpenWebUIRequest):
    resp = rag.query(
        request.messages[-1].content, param=QueryParam(mode="hybrid", stream=True)
    )
    if inspect.isasyncgen(resp):

        async def ollama_resp(chunks):
            async for chunk in chunks:
                yield (
                    json.dumps(
                        {
                            "model": MODEL_NAME,
                            "created_at": datetime.now(timezone.utc).strftime(
                                "%Y-%m-%dT%H:%M:%S.%fZ"
                            ),
                            "message": {
                                "role": "assistant",
                                "content": chunk,
                            },
                            "done": False,
                        },
                        ensure_ascii=False,
                    ).encode("utf-8")
                    + b"\n"
                )  # the b"\n" is important

        return StreamingResponse(ollama_resp(resp), media_type="application/json")
    else:
        return resp


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)

View File

@@ -94,8 +94,6 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode q
After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.
-To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model".

## Configuration

LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
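As an aside, the prefix-based mode selection described above can also be exercised directly against the server's Ollama-compatible /api/chat route (shown further down in this diff). The sketch below is illustrative only: it assumes the default host and port from the test configuration (localhost:9621), no API key configured, the requests library installed, and hypothetical message contents.

import requests  # assumption: not part of this PR, used here only for illustration

payload = {
    "model": "lightrag:latest",
    "messages": [
        {"role": "user", "content": "Who are the main characters of Journey to the West?"},
        {"role": "assistant", "content": "Tang Sanzang, Sun Wukong, Zhu Bajie and Sha Wujing."},
        # The "/mix" prefix selects the mix query mode, as described above
        {"role": "user", "content": "/mix Which of them appears first in the story?"},
    ],
    "stream": False,
}
response = requests.post("http://localhost:9621/api/chat", json=payload, timeout=120)
print(response.json())

With this PR, the two earlier messages are forwarded to LightRAG as conversation history rather than being discarded.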

View File

@@ -17,6 +17,7 @@ import shutil
import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import os
+import sys
import configparser
from fastapi import Depends, Security
@@ -200,8 +201,14 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.max_async}") ASCIIColors.yellow(f"{args.max_async}")
ASCIIColors.white(" ├─ Max Tokens: ", end="") ASCIIColors.white(" ├─ Max Tokens: ", end="")
ASCIIColors.yellow(f"{args.max_tokens}") ASCIIColors.yellow(f"{args.max_tokens}")
ASCIIColors.white(" ─ Max Embed Tokens: ", end="") ASCIIColors.white(" ─ Max Embed Tokens: ", end="")
ASCIIColors.yellow(f"{args.max_embed_tokens}") ASCIIColors.yellow(f"{args.max_embed_tokens}")
ASCIIColors.white(" ├─ Chunk Size: ", end="")
ASCIIColors.yellow(f"{args.chunk_size}")
ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="")
ASCIIColors.yellow(f"{args.chunk_overlap_size}")
ASCIIColors.white(" └─ History Turns: ", end="")
ASCIIColors.yellow(f"{args.history_turns}")
# System Configuration # System Configuration
ASCIIColors.magenta("\n🛠️ System Configuration:") ASCIIColors.magenta("\n🛠️ System Configuration:")
@@ -281,6 +288,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
    ASCIIColors.green("Server is ready to accept connections! 🚀\n")

+    # Ensure splash output flush to system log
+    sys.stdout.flush()


def parse_args() -> argparse.Namespace:
    """
@@ -294,7 +304,7 @@ def parse_args() -> argparse.Namespace:
description="LightRAG FastAPI Server with separate working and input directories" description="LightRAG FastAPI Server with separate working and input directories"
) )
# Bindings (with env var support) # Bindings configuration
parser.add_argument( parser.add_argument(
"--llm-binding", "--llm-binding",
default=get_env_value("LLM_BINDING", "ollama"), default=get_env_value("LLM_BINDING", "ollama"),
@@ -306,9 +316,6 @@ def parse_args() -> argparse.Namespace:
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)", help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
) )
# Parse temporary args for host defaults
temp_args, _ = parser.parse_known_args()
# Server configuration # Server configuration
parser.add_argument( parser.add_argument(
"--host", "--host",
@@ -335,13 +342,13 @@ def parse_args() -> argparse.Namespace:
    )

    # LLM Model configuration
-    default_llm_host = get_env_value(
-        "LLM_BINDING_HOST", get_default_host(temp_args.llm_binding)
-    )
    parser.add_argument(
        "--llm-binding-host",
-        default=default_llm_host,
-        help=f"llm server host URL (default: from env or {default_llm_host})",
+        default=get_env_value("LLM_BINDING_HOST", None),
+        help="LLM server host URL. If not provided, defaults based on llm-binding:\n"
+        + "- ollama: http://localhost:11434\n"
+        + "- lollms: http://localhost:9600\n"
+        + "- openai: https://api.openai.com/v1",
    )

    default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None)
@@ -359,13 +366,13 @@ def parse_args() -> argparse.Namespace:
    )

    # Embedding model configuration
-    default_embedding_host = get_env_value(
-        "EMBEDDING_BINDING_HOST", get_default_host(temp_args.embedding_binding)
-    )
    parser.add_argument(
        "--embedding-binding-host",
-        default=default_embedding_host,
-        help=f"embedding server host URL (default: from env or {default_embedding_host})",
+        default=get_env_value("EMBEDDING_BINDING_HOST", None),
+        help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n"
+        + "- ollama: http://localhost:11434\n"
+        + "- lollms: http://localhost:9600\n"
+        + "- openai: https://api.openai.com/v1",
    )

    default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -383,14 +390,14 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument(
        "--chunk_size",
-        default=1200,
-        help="chunk token size default 1200",
+        default=get_env_value("CHUNK_SIZE", 1200),
+        help="chunk chunk size default 1200",
    )

    parser.add_argument(
        "--chunk_overlap_size",
-        default=100,
-        help="chunk token size default 1200",
+        default=get_env_value("CHUNK_OVERLAP_SIZE", 100),
+        help="chunk overlap size default 100",
    )

    def timeout_type(value):
@@ -470,6 +477,13 @@ def parse_args() -> argparse.Namespace:
help="Enable automatic scanning when the program starts", help="Enable automatic scanning when the program starts",
) )
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", 3, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
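The --history-turns default above relies on get_env_value accepting an optional type converter (it is called as get_env_value("HISTORY_TURNS", 3, int)). That helper is defined elsewhere in the API package and is not part of this diff; the sketch below is only an assumption about its behavior, consistent with the call sites seen here.

import os

def get_env_value(name, default, value_type=str):
    # Hypothetical re-implementation for illustration; the real helper may differ.
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return value_type(raw)
    except (TypeError, ValueError):
        return default

Under that assumption, HISTORY_TURNS=5 in the .env file shown in the first hunk yields the integer 5, while leaving the variable unset (or commented out) keeps the built-in default of 3.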
@@ -634,8 +648,7 @@ def get_api_key_dependency(api_key: Optional[str]):
def create_app(args):
-    # Verify that bindings arer correctly setup
+    # Verify that bindings are correctly setup
    if args.llm_binding not in [
        "lollms",
        "ollama",
@@ -648,6 +661,13 @@ def create_app(args):
    if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
        raise Exception("embedding binding not supported")

+    # Set default hosts if not provided
+    if args.llm_binding_host is None:
+        args.llm_binding_host = get_default_host(args.llm_binding)
+
+    if args.embedding_binding_host is None:
+        args.embedding_binding_host = get_default_host(args.embedding_binding)

    # Add SSL validation
    if args.ssl:
        if not args.ssl_certfile or not args.ssl_keyfile:
@@ -1442,7 +1462,10 @@ def create_app(args):
@app.post("/api/generate") @app.post("/api/generate")
async def generate(raw_request: Request, request: OllamaGenerateRequest): async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests""" """Handle generate completion requests
For compatiblity purpuse, the request is not processed by LightRAG,
and will be handled by underlying LLM model.
"""
try: try:
query = request.prompt query = request.prompt
start_time = time.time_ns() start_time = time.time_ns()
@@ -1581,15 +1604,22 @@ def create_app(args):
@app.post("/api/chat") @app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest): async def chat(raw_request: Request, request: OllamaChatRequest):
"""Handle chat completion requests""" """Process chat completion requests.
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM.
"""
try: try:
# Get all messages # Get all messages
messages = request.messages messages = request.messages
if not messages: if not messages:
raise HTTPException(status_code=400, detail="No messages provided") raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query # Get the last message as query and previous messages as history
query = messages[-1].content query = messages[-1].content
# Convert OllamaMessage objects to dictionaries
conversation_history = [
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
]
# Check for query prefix # Check for query prefix
cleaned_query, mode = parse_query_mode(query) cleaned_query, mode = parse_query_mode(query)
@@ -1597,9 +1627,17 @@ def create_app(args):
            start_time = time.time_ns()
            prompt_tokens = estimate_tokens(cleaned_query)

-            query_param = QueryParam(
-                mode=mode, stream=request.stream, only_need_context=False
-            )
+            param_dict = {
+                "mode": mode,
+                "stream": request.stream,
+                "only_need_context": False,
+                "conversation_history": conversation_history,
+            }
+
+            if args.history_turns is not None:
+                param_dict["history_turns"] = args.history_turns
+
+            query_param = QueryParam(**param_dict)

            if request.stream:
                from fastapi.responses import StreamingResponse
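To make the new data flow concrete, the following standalone sketch mirrors what the handler above does with an incoming message list. The Msg dataclass is a stand-in for the server's OllamaMessage model, and the mode value is what parse_query_mode would be expected to return for a "/local" prefix; both are assumptions for illustration.

from dataclasses import dataclass

@dataclass
class Msg:  # stand-in for OllamaMessage, for illustration only
    role: str
    content: str

messages = [
    Msg("user", "Hello"),
    Msg("assistant", "Hi, how can I help?"),
    Msg("user", "/local Tell me about the main characters."),
]

# Last message becomes the query; everything before it becomes history
query = messages[-1].content
conversation_history = [
    {"role": m.role, "content": m.content} for m in messages[:-1]
]

param_dict = {
    "mode": "local",                 # assumed output of parse_query_mode(query)
    "stream": False,
    "only_need_context": False,
    "conversation_history": conversation_history,
}
history_turns = 3                    # from --history-turns / HISTORY_TURNS
if history_turns is not None:
    param_dict["history_turns"] = history_turns

print(param_dict)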

View File

@@ -633,11 +633,8 @@ async def kg_query(
    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
-        recent_history = query_param.conversation_history[
-            -query_param.history_window_size :
-        ]
-        history_context = "\n".join(
-            [f"{turn['role']}: {turn['content']}" for turn in recent_history]
-        )
+        history_context = get_conversation_turns(
+            query_param.conversation_history, query_param.history_turns
+        )

    sys_prompt_temp = PROMPTS["rag_response"]
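get_conversation_turns itself is not shown in this diff. A minimal sketch of the assumed behavior, grouping the flat message list into user/assistant turns, keeping only the most recent history_turns of them, and flattening the result into prompt text, could look like this (illustration only; the real implementation in lightrag may differ):

def get_conversation_turns(conversation_history, num_turns):
    # Group messages into turns, closing a turn at each assistant reply
    turns, current = [], []
    for msg in conversation_history:
        current.append(msg)
        if msg.get("role") == "assistant":
            turns.append(current)
            current = []
    if current:  # trailing user message without a reply yet
        turns.append(current)
    if num_turns is not None and num_turns > 0:
        turns = turns[-num_turns:]
    return "\n".join(
        f"{m['role']}: {m['content']}" for turn in turns for m in turn
    )

Compared with the removed code, this replaces a raw slice over individual messages (history_window_size) with turn-aware trimming driven by the new history_turns parameter.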

View File

@@ -104,7 +104,7 @@ DEFAULT_CONFIG = {
"host": "localhost", "host": "localhost",
"port": 9621, "port": 9621,
"model": "lightrag:latest", "model": "lightrag:latest",
"timeout": 30, "timeout": 120,
"max_retries": 3, "max_retries": 3,
"retry_delay": 1, "retry_delay": 1,
}, },
@@ -189,19 +189,32 @@ def get_base_url(endpoint: str = "chat") -> str:
def create_chat_request_data(
-    content: str, stream: bool = False, model: str = None
+    content: str,
+    stream: bool = False,
+    model: str = None,
+    conversation_history: List[Dict[str, str]] = None,
+    history_turns: int = None,
) -> Dict[str, Any]:
    """Create chat request data

    Args:
        content: User message content
        stream: Whether to use streaming response
        model: Model name
+        conversation_history: List of previous conversation messages
+        history_turns: Number of history turns to include

    Returns:
        Dictionary containing complete chat request data
    """
+    messages = conversation_history or []
+    if history_turns is not None and conversation_history:
+        messages = messages[
+            -2 * history_turns :
+        ]  # Each turn has 2 messages (user + assistant)
+    messages.append({"role": "user", "content": content})
+
    return {
        "model": model or CONFIG["server"]["model"],
-        "messages": [{"role": "user", "content": content}],
+        "messages": messages,
        "stream": stream,
    }
@@ -259,11 +272,25 @@ def run_test(func: Callable, name: str) -> None:
def test_non_stream_chat() -> None:
    """Test non-streaming call to /api/chat endpoint"""
    url = get_base_url()
-    data = create_chat_request_data(
-        CONFIG["test_cases"]["basic"]["query"], stream=False
-    )

-    # Send request
+    # Example conversation history
+    conversation_history = [
+        {"role": "user", "content": "你好"},
+        {"role": "assistant", "content": "你好！我是一个AI助手，很高兴为你服务。"},
+        {"role": "user", "content": "西游记里有几个主要人物？"},
+        {
+            "role": "assistant",
+            "content": "西游记的主要人物有唐僧、孙悟空、猪八戒、沙和尚这四位主角。",
+        },
+    ]
+
+    # Send request with conversation history and history turns
+    data = create_chat_request_data(
+        CONFIG["test_cases"]["basic"]["query"],
+        stream=False,
+        conversation_history=conversation_history,
+        history_turns=2,  # Only include last 2 turns
+    )
    response = make_request(url, data)

    # Print response
@@ -297,9 +324,25 @@ def test_stream_chat() -> None:
    The last message will contain performance statistics, with done set to true.
    """
    url = get_base_url()
-    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)

-    # Send request and get streaming response
+    # Example conversation history
+    conversation_history = [
+        {"role": "user", "content": "你好"},
+        {"role": "assistant", "content": "你好！我是一个AI助手，很高兴为你服务。"},
+        {"role": "user", "content": "西游记里有几个主要人物？"},
+        {
+            "role": "assistant",
+            "content": "西游记的主要人物有唐僧、孙悟空、猪八戒、沙和尚这四位主角。",
+        },
+    ]
+
+    # Send request with conversation history and history turns
+    data = create_chat_request_data(
+        CONFIG["test_cases"]["basic"]["query"],
+        stream=True,
+        conversation_history=conversation_history,
+        history_turns=2,  # Only include last 2 turns
+    )
    response = make_request(url, data, stream=True)

    if OutputControl.is_verbose():