Merge pull request #592 from danielaskdd/yangdx

Add Ollama compatible API server
This commit is contained in:
zrguo
2025-01-17 14:29:31 +08:00
committed by GitHub
6 changed files with 1502 additions and 2 deletions

View File

@@ -716,7 +716,7 @@ Output the results in the following structure:
```
</details>
### Batch Eval
### Batch Eval
To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
<details>
@@ -767,6 +767,7 @@ Output your evaluation in the following JSON format:
</details>
### Overall Performance Table
| | **Agriculture** | | **CS** | | **Legal** | | **Mix** | |
|----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
| | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** |

View File

@@ -0,0 +1,924 @@
from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request
from pydantic import BaseModel
import logging
import argparse
import json
import time
import re
from typing import List, Dict, Any, Optional
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, ollama_embedding
from lightrag.utils import EmbeddingFunc
from enum import Enum
from pathlib import Path
import shutil
import aiofiles
from ascii_colors import trace_exception
import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
load_dotenv()
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
Chinese characters: approximately 1.5 tokens per character
English characters: approximately 0.25 tokens per character
"""
# Use regex to match Chinese and non-Chinese characters separately
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
# Calculate estimated token count
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
return int(tokens)
# Constants for model information
LIGHTRAG_NAME = "lightrag"
LIGHTRAG_TAG = "latest"
LIGHTRAG_MODEL = "lightrag:latest"
LIGHTRAG_SIZE = 7365960935
LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z"
LIGHTRAG_DIGEST = "sha256:lightrag"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
return await openai_complete_if_cache(
"deepseek-chat",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("DEEPSEEK_API_KEY"),
base_url=os.getenv("DEEPSEEK_ENDPOINT"),
**kwargs,
)
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": "http://m4.lan.znipower.com:11434",
"lollms": "http://localhost:9600",
"azure_openai": "https://api.openai.com/v1",
"openai": os.getenv("DEEPSEEK_ENDPOINT"),
}
return default_hosts.get(
binding_type, "http://localhost:11434"
) # fallback to ollama if unknown
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Start by the bindings
parser.add_argument(
"--llm-binding",
default="ollama",
help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
)
parser.add_argument(
"--embedding-binding",
default="ollama",
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
)
# Parse just these arguments first
temp_args, _ = parser.parse_known_args()
# Add remaining arguments with dynamic defaults for hosts
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
)
# LLM Model configuration
default_llm_host = get_default_host(temp_args.llm_binding)
parser.add_argument(
"--llm-binding-host",
default=default_llm_host,
help=f"llm server host URL (default: {default_llm_host})",
)
parser.add_argument(
"--llm-model",
default="mistral-nemo:latest",
help="LLM model name (default: mistral-nemo:latest)",
)
# Embedding model configuration
default_embedding_host = get_default_host(temp_args.embedding_binding)
parser.add_argument(
"--embedding-binding-host",
default=default_embedding_host,
help=f"embedding server host URL (default: {default_embedding_host})",
)
parser.add_argument(
"--embedding-model",
default="bge-m3:latest",
help="Embedding model name (default: bge-m3:latest)",
)
def timeout_type(value):
if value is None or value == "None":
return None
return int(value)
parser.add_argument(
"--timeout",
default=None,
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
)
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=1024,
help="Embedding dimensions (default: 1024)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
)
parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
# Optional https parameters
parser.add_argument(
"--ssl", action="store_true", help="Enable HTTPS (default: False)"
)
parser.add_argument(
"--ssl-certfile",
default=None,
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=None,
help="Path to SSL private key file (required if --ssl is enabled)",
)
return parser.parse_args()
class DocumentManager:
"""Handles document operations and tracking"""
def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")):
self.input_dir = Path(input_dir)
self.supported_extensions = supported_extensions
self.indexed_files = set()
# Create input directory if it doesn't exist
self.input_dir.mkdir(parents=True, exist_ok=True)
def scan_directory(self) -> List[Path]:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
def mark_as_indexed(self, file_path: Path):
"""Mark a file as indexed"""
self.indexed_files.add(file_path)
def is_supported_file(self, filename: str) -> bool:
"""Check if file type is supported"""
return any(filename.lower().endswith(ext) for ext in self.supported_extensions)
# Pydantic models
class SearchMode(str, Enum):
naive = "naive"
local = "local"
global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global"
hybrid = "hybrid"
mix = "mix"
# Ollama API compatible models
class OllamaMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
class OllamaChatRequest(BaseModel):
model: str = LIGHTRAG_MODEL
messages: List[OllamaMessage]
stream: bool = True # Default to streaming mode
options: Optional[Dict[str, Any]] = None
class OllamaChatResponse(BaseModel):
model: str
created_at: str
message: OllamaMessage
done: bool
class OllamaVersionResponse(BaseModel):
version: str
class OllamaModelDetails(BaseModel):
parent_model: str
format: str
family: str
families: List[str]
parameter_size: str
quantization_level: str
class OllamaModel(BaseModel):
name: str
model: str
size: int
digest: str
modified_at: str
details: OllamaModelDetails
class OllamaTagResponse(BaseModel):
models: List[OllamaModel]
# Original LightRAG models
class QueryRequest(BaseModel):
query: str
mode: SearchMode = SearchMode.hybrid
stream: bool = False
only_need_context: bool = False
class QueryResponse(BaseModel):
response: str
class InsertTextRequest(BaseModel):
text: str
description: Optional[str] = None
class InsertResponse(BaseModel):
status: str
message: str
document_count: int
def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None
return no_auth
# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value
return api_key_auth
def create_app(args):
# Verify that bindings arer correctly setup
if args.llm_binding not in ["lollms", "ollama", "openai"]:
raise Exception("llm binding not supported")
if args.embedding_binding not in ["lollms", "ollama", "openai"]:
raise Exception("embedding binding not supported")
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.1",
openapi_tags=[{"name": "api"}],
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Initialize RAG
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=1024,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="bge-m3:latest",
host="http://m4.lan.znipower.com:11434",
),
),
)
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
try:
# Use async file reading
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
content = await f.read()
# Use the async version of insert directly
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
logging.info(f"Indexed file: {file_path}")
except Exception as e:
trace_exception(e)
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
new_files = doc_manager.scan_directory()
indexed_count = 0
for file_path in new_files:
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
indexed_count += 1
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
return {
"status": "success",
"indexed_count": indexed_count,
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
if not doc_manager.is_supported_file(file.filename):
raise HTTPException(
status_code=400,
detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}",
)
file_path = doc_manager.input_dir / file.filename
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Immediately index the uploaded file
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
await rag.ainsert(content)
doc_manager.mark_as_indexed(file_path)
return {
"status": "success",
"message": f"File uploaded and indexed: {file.filename}",
"total_documents": len(doc_manager.indexed_files),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
request.query,
param=QueryParam(
mode=request.mode,
stream=request.stream,
only_need_context=request.only_need_context,
),
)
# If response is a string (e.g. cache hit), return directly
if isinstance(response, str):
return QueryResponse(response=response)
# If it's an async generator, decide whether to stream based on stream parameter
if request.stream:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
else:
result = ""
async for chunk in response:
result += chunk
return QueryResponse(response=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = await rag.aquery( # Use aquery instead of query, and add await
request.query,
param=QueryParam(
mode=request.mode,
stream=True,
only_need_context=request.only_need_context,
),
)
from fastapi.responses import StreamingResponse
async def stream_generator():
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
else:
# If it's an async generator, send chunks one by one
try:
async for chunk in response:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
"X-Accel-Buffering": "no", # Disable Nginx buffering
},
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
return InsertResponse(
status="success",
message="Text successfully inserted",
document_count=1,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
await rag.ainsert(text)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file type. Only .txt and .md files are supported",
)
return InsertResponse(
status="success",
message=f"File '{file.filename}' successfully inserted",
document_count=1,
)
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File encoding not supported")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
failed_files = []
for file in files:
try:
content = await file.read()
if file.filename.endswith((".txt", ".md")):
text = content.decode("utf-8")
await rag.ainsert(text)
inserted_count += 1
else:
failed_files.append(f"{file.filename} (unsupported type)")
except Exception as e:
failed_files.append(f"{file.filename} ({str(e)})")
status_message = f"Successfully inserted {inserted_count} documents"
if failed_files:
status_message += f". Failed files: {', '.join(failed_files)}"
return InsertResponse(
status="success" if inserted_count > 0 else "partial_success",
message=status_message,
document_count=len(files),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
rag.entities_vdb = None
rag.relationships_vdb = None
return InsertResponse(
status="success",
message="All documents cleared successfully",
document_count=0,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Ollama compatible API endpoints
@app.get("/api/version")
async def get_version():
"""Get Ollama version information"""
return OllamaVersionResponse(version="0.5.4")
@app.get("/api/tags")
async def get_tags():
"""Get available models"""
return OllamaTagResponse(
models=[
{
"name": LIGHTRAG_MODEL,
"model": LIGHTRAG_MODEL,
"size": LIGHTRAG_SIZE,
"digest": LIGHTRAG_DIGEST,
"modified_at": LIGHTRAG_CREATED_AT,
"details": {
"parent_model": "",
"format": "gguf",
"family": LIGHTRAG_NAME,
"families": [LIGHTRAG_NAME],
"parameter_size": "13B",
"quantization_level": "Q4_0",
},
}
]
)
def parse_query_mode(query: str) -> tuple[str, SearchMode]:
"""Parse query prefix to determine search mode
Returns tuple of (cleaned_query, search_mode)
"""
mode_map = {
"/local ": SearchMode.local,
"/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword
"/naive ": SearchMode.naive,
"/hybrid ": SearchMode.hybrid,
"/mix ": SearchMode.mix,
}
for prefix, mode in mode_map.items():
if query.startswith(prefix):
# After removing prefix an leading spaces
cleaned_query = query[len(prefix) :].lstrip()
return cleaned_query, mode
return query, SearchMode.hybrid
@app.post("/api/chat")
async def chat(raw_request: Request, request: OllamaChatRequest):
"""Handle chat completion requests"""
try:
# Get all messages
messages = request.messages
if not messages:
raise HTTPException(status_code=400, detail="No messages provided")
# Get the last message as query
query = messages[-1].content
# 解析查询模式
cleaned_query, mode = parse_query_mode(query)
# 开始计时
start_time = time.time_ns()
# 计算输入token数量
prompt_tokens = estimate_tokens(cleaned_query)
# 调用RAG进行查询
query_param = QueryParam(
mode=mode, stream=request.stream, only_need_context=False
)
if request.stream:
from fastapi.responses import StreamingResponse
response = await rag.aquery( # Need await to get async generator
cleaned_query, param=query_param
)
async def stream_generator():
try:
first_chunk_time = None
last_chunk_time = None
total_response = ""
# Ensure response is an async generator
if isinstance(response, str):
# If it's a string, send in two parts
first_chunk_time = time.time_ns()
last_chunk_time = first_chunk_time
total_response = response
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": response,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
else:
async for chunk in response:
if chunk:
if first_chunk_time is None:
first_chunk_time = time.time_ns()
last_chunk_time = time.time_ns()
total_response += chunk
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": chunk,
"images": None,
},
"done": False,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
completion_tokens = estimate_tokens(total_response)
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
data = {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
yield f"{json.dumps(data, ensure_ascii=False)}\n"
return # Ensure the generator ends immediately after sending the completion marker
except Exception as e:
logging.error(f"Error in stream_generator: {str(e)}")
raise
return StreamingResponse(
stream_generator(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Content-Type": "application/x-ndjson",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
},
)
else:
first_chunk_time = time.time_ns()
response_text = await rag.aquery(cleaned_query, param=query_param)
last_chunk_time = time.time_ns()
if not response_text:
response_text = "No response generated"
completion_tokens = estimate_tokens(str(response_text))
total_time = last_chunk_time - start_time
prompt_eval_time = first_chunk_time - start_time
eval_time = last_chunk_time - first_chunk_time
return {
"model": LIGHTRAG_MODEL,
"created_at": LIGHTRAG_CREATED_AT,
"message": {
"role": "assistant",
"content": str(response_text),
"images": None,
},
"done": True,
"total_duration": total_time,
"load_duration": 0,
"prompt_eval_count": prompt_tokens,
"prompt_eval_duration": prompt_eval_time,
"eval_count": completion_tokens,
"eval_duration": eval_time,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {
"status": "healthy",
"working_directory": str(args.working_dir),
"input_directory": str(args.input_dir),
"indexed_files": len(doc_manager.indexed_files),
"configuration": {
# LLM configuration binding/host address (if applicable)/model (if applicable)
"llm_binding": args.llm_binding,
"llm_binding_host": args.llm_binding_host,
"llm_model": args.llm_model,
# embedding model configuration binding/host address (if applicable)/model (if applicable)
"embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
},
}
return app
def main():
args = parse_args()
import uvicorn
app = create_app(args)
uvicorn_config = {
"app": app,
"host": args.host,
"port": args.port,
}
if args.ssl:
uvicorn_config.update(
{
"ssl_certfile": args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile,
}
)
uvicorn.run(**uvicorn_config)
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,6 @@
aioboto3
ascii_colors
fastapi
lightrag-hku
nano_vectordb
nest_asyncio
numpy

View File

@@ -101,6 +101,7 @@ setuptools.setup(
entry_points={
"console_scripts": [
"lightrag-server=lightrag.api.lightrag_server:main [api]",
"lightrag-ollama=lightrag.api.lightrag_ollama:main [api]",
],
},
)

3
start-server.sh Executable file
View File

@@ -0,0 +1,3 @@
. venv/bin/activate
lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024

View File

@@ -0,0 +1,572 @@
"""
LightRAG Ollama Compatibility Interface Test Script
This script tests the LightRAG's Ollama compatibility interface, including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid)
3. Error handling tests (including streaming and non-streaming scenarios)
All responses use the JSON Lines format, complying with the Ollama API specification.
"""
import requests
import json
import argparse
import time
from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
class OutputControl:
"""Output control class, manages the verbosity of test output"""
_verbose: bool = False
@classmethod
def set_verbose(cls, verbose: bool) -> None:
cls._verbose = verbose
@classmethod
def is_verbose(cls) -> bool:
return cls._verbose
@dataclass
class TestResult:
"""Test result data class"""
name: str
success: bool
duration: float
error: Optional[str] = None
timestamp: str = ""
def __post_init__(self):
if not self.timestamp:
self.timestamp = datetime.now().isoformat()
class TestStats:
"""Test statistics"""
def __init__(self):
self.results: List[TestResult] = []
self.start_time = datetime.now()
def add_result(self, result: TestResult):
self.results.append(result)
def export_results(self, path: str = "test_results.json"):
"""Export test results to a JSON file
Args:
path: Output file path
"""
results_data = {
"start_time": self.start_time.isoformat(),
"end_time": datetime.now().isoformat(),
"results": [asdict(r) for r in self.results],
"summary": {
"total": len(self.results),
"passed": sum(1 for r in self.results if r.success),
"failed": sum(1 for r in self.results if not r.success),
"total_duration": sum(r.duration for r in self.results),
},
}
with open(path, "w", encoding="utf-8") as f:
json.dump(results_data, f, ensure_ascii=False, indent=2)
print(f"\nTest results saved to: {path}")
def print_summary(self):
total = len(self.results)
passed = sum(1 for r in self.results if r.success)
failed = total - passed
duration = sum(r.duration for r in self.results)
print("\n=== Test Summary ===")
print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total duration: {duration:.2f} seconds")
print(f"Total tests: {total}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if failed > 0:
print("\nFailed tests:")
for result in self.results:
if not result.success:
print(f"- {result.name}: {result.error}")
DEFAULT_CONFIG = {
"server": {
"host": "localhost",
"port": 9621,
"model": "lightrag:latest",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1,
},
"test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
}
def make_request(
url: str, data: Dict[str, Any], stream: bool = False
) -> requests.Response:
"""Send an HTTP request with retry mechanism
Args:
url: Request URL
data: Request data
stream: Whether to use streaming response
Returns:
requests.Response: Response object
Raises:
requests.exceptions.RequestException: Request failed after all retries
"""
server_config = CONFIG["server"]
max_retries = server_config["max_retries"]
retry_delay = server_config["retry_delay"]
timeout = server_config["timeout"]
for attempt in range(max_retries):
try:
response = requests.post(url, json=data, stream=stream, timeout=timeout)
return response
except requests.exceptions.RequestException as e:
if attempt == max_retries - 1: # Last retry
raise
print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
time.sleep(retry_delay)
def load_config() -> Dict[str, Any]:
"""Load configuration file
First try to load from config.json in the current directory,
if it doesn't exist, use the default configuration
Returns:
Configuration dictionary
"""
config_path = Path("config.json")
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
return DEFAULT_CONFIG
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
"""Format and print JSON response data
Args:
data: Data dictionary to print
title: Title to print
indent: Number of spaces for JSON indentation
"""
if OutputControl.is_verbose():
if title:
print(f"\n=== {title} ===")
print(json.dumps(data, ensure_ascii=False, indent=indent))
# Global configuration
CONFIG = load_config()
def get_base_url() -> str:
"""Return the base URL"""
server = CONFIG["server"]
return f"http://{server['host']}:{server['port']}/api/chat"
def create_request_data(
content: str, stream: bool = False, model: str = None
) -> Dict[str, Any]:
"""Create basic request data
Args:
content: User message content
stream: Whether to use streaming response
model: Model name
Returns:
Dictionary containing complete request data
"""
return {
"model": model or CONFIG["server"]["model"],
"messages": [{"role": "user", "content": content}],
"stream": stream,
}
# Global test statistics
STATS = TestStats()
def run_test(func: Callable, name: str) -> None:
"""Run a test and record the results
Args:
func: Test function
name: Test name
"""
start_time = time.time()
try:
func()
duration = time.time() - start_time
STATS.add_result(TestResult(name, True, duration))
except Exception as e:
duration = time.time() - start_time
STATS.add_result(TestResult(name, False, duration, str(e)))
raise
def test_non_stream_chat():
"""Test non-streaming call to /api/chat endpoint"""
url = get_base_url()
data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
# Send request
response = make_request(url, data)
# Print response
if OutputControl.is_verbose():
print("\n=== Non-streaming call response ===")
response_json = response.json()
# Print response content
print_json_response(
{"model": response_json["model"], "message": response_json["message"]},
"Response content",
)
def test_stream_chat():
"""Test streaming call to /api/chat endpoint
Use JSON Lines format to process streaming responses, each line is a complete JSON object.
Response format:
{
"model": "lightrag:latest",
"created_at": "2024-01-15T00:00:00Z",
"message": {
"role": "assistant",
"content": "Partial response content",
"images": null
},
"done": false
}
The last message will contain performance statistics, with done set to true.
"""
url = get_base_url()
data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
# Send request and get streaming response
response = make_request(url, data, stream=True)
if OutputControl.is_verbose():
print("\n=== Streaming call response ===")
output_buffer = []
try:
for line in response.iter_lines():
if line: # Skip empty lines
try:
# Decode and parse JSON
data = json.loads(line.decode("utf-8"))
if data.get("done", True): # If it's the completion marker
if (
"total_duration" in data
): # Final performance statistics message
# print_json_response(data, "Performance statistics")
break
else: # Normal content message
message = data.get("message", {})
content = message.get("content", "")
if content: # Only collect non-empty content
output_buffer.append(content)
print(
content, end="", flush=True
) # Print content in real-time
except json.JSONDecodeError:
print("Error decoding JSON from response line")
finally:
response.close() # Ensure the response connection is closed
# Print a newline
print()
def test_query_modes():
"""Test different query mode prefixes
Supported query modes:
- /local: Local retrieval mode, searches only in highly relevant documents
- /global: Global retrieval mode, searches across all documents
- /naive: Naive mode, does not use any optimization strategies
- /hybrid: Hybrid mode (default), combines multiple strategies
- /mix: Mix mode
Each mode will return responses in the same format, but with different retrieval strategies.
"""
url = get_base_url()
modes = ["local", "global", "naive", "hybrid", "mix"]
for mode in modes:
if OutputControl.is_verbose():
print(f"\n=== Testing /{mode} mode ===")
data = create_request_data(
f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
)
# Send request
response = make_request(url, data)
response_json = response.json()
# Print response content
print_json_response(
{"model": response_json["model"], "message": response_json["message"]}
)
def create_error_test_data(error_type: str) -> Dict[str, Any]:
"""Create request data for error testing
Args:
error_type: Error type, supported:
- empty_messages: Empty message list
- invalid_role: Invalid role field
- missing_content: Missing content field
Returns:
Request dictionary containing error data
"""
error_data = {
"empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
"invalid_role": {
"model": "lightrag:latest",
"messages": [{"invalid_role": "user", "content": "Test message"}],
"stream": True,
},
"missing_content": {
"model": "lightrag:latest",
"messages": [{"role": "user"}],
"stream": True,
},
}
return error_data.get(error_type, error_data["empty_messages"])
def test_stream_error_handling():
"""Test error handling for streaming responses
Test scenarios:
1. Empty message list
2. Message format error (missing required fields)
Error responses should be returned immediately without establishing a streaming connection.
The status code should be 4xx, and detailed error information should be returned.
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== Testing streaming response error handling ===")
# Test empty message list
if OutputControl.is_verbose():
print("\n--- Testing empty message list (streaming) ---")
data = create_error_test_data("empty_messages")
response = make_request(url, data, stream=True)
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "Error message")
response.close()
# Test invalid role field
if OutputControl.is_verbose():
print("\n--- Testing invalid role field (streaming) ---")
data = create_error_test_data("invalid_role")
response = make_request(url, data, stream=True)
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "Error message")
response.close()
# Test missing content field
if OutputControl.is_verbose():
print("\n--- Testing missing content field (streaming) ---")
data = create_error_test_data("missing_content")
response = make_request(url, data, stream=True)
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "Error message")
response.close()
def test_error_handling():
"""Test error handling for non-streaming responses
Test scenarios:
1. Empty message list
2. Message format error (missing required fields)
Error response format:
{
"detail": "Error description"
}
All errors should return appropriate HTTP status codes and clear error messages.
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== Testing error handling ===")
# Test empty message list
if OutputControl.is_verbose():
print("\n--- Testing empty message list ---")
data = create_error_test_data("empty_messages")
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
# Test invalid role field
if OutputControl.is_verbose():
print("\n--- Testing invalid role field ---")
data = create_error_test_data("invalid_role")
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
# Test missing content field
if OutputControl.is_verbose():
print("\n--- Testing missing content field ---")
data = create_error_test_data("missing_content")
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
def get_test_cases() -> Dict[str, Callable]:
"""Get all available test cases
Returns:
A dictionary mapping test names to test functions
"""
return {
"non_stream": test_non_stream_chat,
"stream": test_stream_chat,
"modes": test_query_modes,
"errors": test_error_handling,
"stream_errors": test_stream_error_handling,
}
def create_default_config():
"""Create a default configuration file"""
config_path = Path("config.json")
if not config_path.exists():
with open(config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
print(f"Default configuration file created: {config_path}")
def parse_args() -> argparse.Namespace:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="LightRAG Ollama Compatibility Interface Testing",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Configuration file (config.json):
{
"server": {
"host": "localhost", # Server address
"port": 9621, # Server port
"model": "lightrag:latest" # Default model name
},
"test_cases": {
"basic": {
"query": "Test query", # Basic query text
"stream_query": "Stream query" # Stream query text
}
}
}
""",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="Silent mode, only display test result summary",
)
parser.add_argument(
"-a",
"--ask",
type=str,
help="Specify query content, which will override the query settings in the configuration file",
)
parser.add_argument(
"--init-config", action="store_true", help="Create default configuration file"
)
parser.add_argument(
"--output",
type=str,
default="",
help="Test result output file path, default is not to output to a file",
)
parser.add_argument(
"--tests",
nargs="+",
choices=list(get_test_cases().keys()) + ["all"],
default=["all"],
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Set output mode
OutputControl.set_verbose(not args.quiet)
# If query content is specified, update the configuration
if args.ask:
CONFIG["test_cases"]["basic"]["query"] = args.ask
# If specified to create a configuration file
if args.init_config:
create_default_config()
exit(0)
test_cases = get_test_cases()
try:
if "all" in args.tests:
# Run all tests
if OutputControl.is_verbose():
print("\n【Basic Functionality Tests】")
run_test(test_non_stream_chat, "Non-streaming Call Test")
run_test(test_stream_chat, "Streaming Call Test")
if OutputControl.is_verbose():
print("\n【Query Mode Tests】")
run_test(test_query_modes, "Query Mode Test")
if OutputControl.is_verbose():
print("\n【Error Handling Tests】")
run_test(test_error_handling, "Error Handling Test")
run_test(test_stream_error_handling, "Streaming Error Handling Test")
else:
# Run specified tests
for test_name in args.tests:
if OutputControl.is_verbose():
print(f"\n【Running Test: {test_name}")
run_test(test_cases[test_name], test_name)
except Exception as e:
print(f"\nAn error occurred: {str(e)}")
finally:
# Print test statistics
STATS.print_summary()
# If an output file path is specified, export the results
if args.output:
STATS.export_results(args.output)