Merge pull request #650 from danielaskdd/Add-history-support-for-ollama-api

Add history support for ollama api
This commit is contained in:
zrguo
2025-01-27 06:34:10 +08:00
committed by GitHub
6 changed files with 122 additions and 183 deletions


@@ -17,6 +17,7 @@ import shutil
import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import os
import sys
import configparser
from fastapi import Depends, Security
@@ -200,8 +201,14 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.max_async}")
ASCIIColors.white(" ├─ Max Tokens: ", end="")
ASCIIColors.yellow(f"{args.max_tokens}")
ASCIIColors.white(" ─ Max Embed Tokens: ", end="")
ASCIIColors.white(" ─ Max Embed Tokens: ", end="")
ASCIIColors.yellow(f"{args.max_embed_tokens}")
ASCIIColors.white(" ├─ Chunk Size: ", end="")
ASCIIColors.yellow(f"{args.chunk_size}")
ASCIIColors.white(" ├─ Chunk Overlap Size: ", end="")
ASCIIColors.yellow(f"{args.chunk_overlap_size}")
ASCIIColors.white(" └─ History Turns: ", end="")
ASCIIColors.yellow(f"{args.history_turns}")
# System Configuration
ASCIIColors.magenta("\n🛠️ System Configuration:")
@@ -281,6 +288,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.green("Server is ready to accept connections! 🚀\n")
# Ensure splash output is flushed to the system log
sys.stdout.flush()
def parse_args() -> argparse.Namespace:
"""
@@ -294,7 +304,7 @@ def parse_args() -> argparse.Namespace:
description="LightRAG FastAPI Server with separate working and input directories"
)
# Bindings (with env var support)
# Bindings configuration
parser.add_argument(
"--llm-binding",
default=get_env_value("LLM_BINDING", "ollama"),
@@ -306,9 +316,6 @@ def parse_args() -> argparse.Namespace:
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
)
# Parse temporary args for host defaults
temp_args, _ = parser.parse_known_args()
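For context, the env-backed defaults used throughout this file rely on a get_env_value helper; a minimal sketch consistent with the call sites in this diff (the fallback-on-parse-error behavior is an assumption, not the actual lightrag/api implementation):

import os

def get_env_value(env_key: str, default, value_type=str):
    # Read an environment variable, falling back to a default value
    value = os.environ.get(env_key)
    if value is None:
        return default
    try:
        return value_type(value)  # e.g. int for HISTORY_TURNS
    except (TypeError, ValueError):
        return default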
# Server configuration
parser.add_argument(
"--host",
@@ -335,13 +342,13 @@ def parse_args() -> argparse.Namespace:
)
# LLM Model configuration
default_llm_host = get_env_value(
"LLM_BINDING_HOST", get_default_host(temp_args.llm_binding)
)
parser.add_argument(
"--llm-binding-host",
default=default_llm_host,
help=f"llm server host URL (default: from env or {default_llm_host})",
default=get_env_value("LLM_BINDING_HOST", None),
help="LLM server host URL. If not provided, defaults based on llm-binding:\n"
+ "- ollama: http://localhost:11434\n"
+ "- lollms: http://localhost:9600\n"
+ "- openai: https://api.openai.com/v1",
)
default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None)
@@ -359,13 +366,13 @@ def parse_args() -> argparse.Namespace:
)
# Embedding model configuration
default_embedding_host = get_env_value(
"EMBEDDING_BINDING_HOST", get_default_host(temp_args.embedding_binding)
)
parser.add_argument(
"--embedding-binding-host",
default=default_embedding_host,
help=f"embedding server host URL (default: from env or {default_embedding_host})",
default=get_env_value("EMBEDDING_BINDING_HOST", None),
help="Embedding server host URL. If not provided, defaults based on embedding-binding:\n"
+ "- ollama: http://localhost:11434\n"
+ "- lollms: http://localhost:9600\n"
+ "- openai: https://api.openai.com/v1",
)
default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
@@ -383,14 +390,14 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--chunk_size",
default=1200,
help="chunk token size default 1200",
default=get_env_value("CHUNK_SIZE", 1200),
help="chunk chunk size default 1200",
)
parser.add_argument(
"--chunk_overlap_size",
default=100,
help="chunk token size default 1200",
default=get_env_value("CHUNK_OVERLAP_SIZE", 100),
help="chunk overlap size default 100",
)
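With env-backed defaults, each setting can come from three places, with CLI flags presumably taking precedence over environment variables, which take precedence over hard-coded defaults. A hypothetical check of that assumed precedence (values invented):

import os

# Hypothetical run with no CLI flags: env vars win over hard-coded defaults
os.environ["CHUNK_SIZE"] = "1500"
os.environ["CHUNK_OVERLAP_SIZE"] = "150"
args = parse_args()
assert int(args.chunk_size) == 1500        # from CHUNK_SIZE, not 1200
assert int(args.chunk_overlap_size) == 150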
def timeout_type(value):
@@ -470,6 +477,13 @@ def parse_args() -> argparse.Namespace:
help="Enable automatic scanning when the program starts",
)
parser.add_argument(
"--history-turns",
type=int,
default=get_env_value("HISTORY_TURNS", 3, int),
help="Number of conversation history turns to include (default: from env or 3)",
)
args = parser.parse_args()
return args
@@ -634,8 +648,7 @@ def get_api_key_dependency(api_key: Optional[str]):
def create_app(args):
# Verify that bindings arer correctly setup
# Verify that bindings are correctly setup
if args.llm_binding not in [
"lollms",
"ollama",
@@ -648,6 +661,13 @@ def create_app(args):
if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
raise Exception("embedding binding not supported")
# Set default hosts if not provided
if args.llm_binding_host is None:
args.llm_binding_host = get_default_host(args.llm_binding)
if args.embedding_binding_host is None:
args.embedding_binding_host = get_default_host(args.embedding_binding)
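The deferred defaults depend on get_default_host; a plausible sketch based on the URLs listed in the help text above (the exact mapping and the fallback for unknown bindings are assumptions):

def get_default_host(binding_type: str) -> str:
    # Map a binding name to its default server URL (from the help text)
    default_hosts = {
        "ollama": "http://localhost:11434",
        "lollms": "http://localhost:9600",
        "openai": "https://api.openai.com/v1",
    }
    # Fall back to the ollama default for unknown bindings (assumption)
    return default_hosts.get(binding_type, "http://localhost:11434")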
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
@@ -1442,7 +1462,10 @@ def create_app(args):
@app.post("/api/generate")
async def generate(raw_request: Request, request: OllamaGenerateRequest):
"""Handle generate completion requests"""
"""Handle generate completion requests
For compatibility purposes, the request is not processed by LightRAG,
and is handled directly by the underlying LLM model.
"""
try:
query = request.prompt
start_time = time.time_ns()
@@ -1581,15 +1604,22 @@ def create_app(args):
@app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Handle chat completion requests"""
"""Process chat completion requests.
Routes user queries through LightRAG by selecting query mode based on prefix indicators.
Detects and forwards OpenWebUI session-related requests (for metadata generation tasks) directly to the LLM.
"""
try:
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query
# Get the last message as query and previous messages as history
query = messages[-1].content
# Convert OllamaMessage objects to dictionaries
conversation_history = [
{"role": msg.role, "content": msg.content} for msg in messages[:-1]
]
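As an illustration, a hypothetical three-message payload splits as follows (contents invented; shown as dicts for brevity, whereas the endpoint actually receives OllamaMessage objects):

messages = [
    {"role": "user", "content": "What is LightRAG?"},
    {"role": "assistant", "content": "A graph-based RAG library."},
    {"role": "user", "content": "How does it store entities?"},
]
query = messages[-1]["content"]        # "How does it store entities?"
conversation_history = messages[:-1]   # the two earlier turns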
# Check for query prefix
cleaned_query, mode = parse_query_mode(query)
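parse_query_mode is expected to peel a mode prefix off the query; a minimal sketch (the exact prefix set and the default mode are assumptions):

def parse_query_mode(query: str):
    # Map a leading "/mode " prefix to a LightRAG query mode
    mode_map = {
        "/local ": "local",
        "/global ": "global",
        "/hybrid ": "hybrid",
        "/naive ": "naive",
        "/mix ": "mix",
    }
    for prefix, mode in mode_map.items():
        if query.startswith(prefix):
            return query[len(prefix):], mode
    return query, "hybrid"  # assumed default when no prefix is present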
@@ -1597,9 +1627,17 @@ def create_app(args):
start_time = time.time_ns()
prompt_tokens = estimate_tokens(cleaned_query)
query_param = QueryParam(
mode=mode, stream=request.stream, only_need_context=False
)
param_dict = {
"mode": mode,
"stream": request.stream,
"only_need_context": False,
"conversation_history": conversation_history,
}
if args.history_turns is not None:
param_dict["history_turns"] = args.history_turns
query_param = QueryParam(**param_dict)
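Downstream, history_turns presumably caps how much of conversation_history reaches the LLM; a sketch of that trimming, under the assumption that one turn is a user/assistant pair (not the actual LightRAG internals):

def trim_history(conversation_history: list, history_turns: int) -> list:
    # Keep only the last N turns; one turn == two entries (user + assistant)
    if history_turns is None or history_turns <= 0:
        return conversation_history
    return conversation_history[-2 * history_turns:]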
if request.stream:
from fastapi.responses import StreamingResponse