diff --git a/README.md b/README.md index 3ebbe290..db049796 100644 --- a/README.md +++ b/README.md @@ -716,7 +716,7 @@ Output the results in the following structure: ``` - ### Batch Eval +### Batch Eval To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
@@ -767,6 +767,7 @@ Output your evaluation in the following JSON format:
### Overall Performance Table + | | **Agriculture** | | **CS** | | **Legal** | | **Mix** | | |----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------| | | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | diff --git a/lightrag/api/lightrag_ollama.py b/lightrag/api/lightrag_ollama.py new file mode 100644 index 00000000..fc7ae29c --- /dev/null +++ b/lightrag/api/lightrag_ollama.py @@ -0,0 +1,924 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request +from pydantic import BaseModel +import logging +import argparse +import json +import time +import re +from typing import List, Dict, Any, Optional +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, ollama_embedding + +from lightrag.utils import EmbeddingFunc +from enum import Enum +from pathlib import Path +import shutil +import aiofiles +from ascii_colors import trace_exception +import os + +from fastapi import Depends, Security +from fastapi.security import APIKeyHeader +from fastapi.middleware.cors import CORSMiddleware + +from starlette.status import HTTP_403_FORBIDDEN + +from dotenv import load_dotenv + +load_dotenv() + + +def estimate_tokens(text: str) -> int: + """Estimate the number of tokens in text + Chinese characters: approximately 1.5 tokens per character + English characters: approximately 0.25 tokens per character + """ + # Use regex to match Chinese and non-Chinese characters separately + chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) + non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text)) + + # Calculate estimated token count + tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25 + + return int(tokens) + + +# Constants for model information +LIGHTRAG_NAME = "lightrag" +LIGHTRAG_TAG = "latest" +LIGHTRAG_MODEL = "lightrag:latest" +LIGHTRAG_SIZE = 7365960935 +LIGHTRAG_CREATED_AT = "2024-01-15T00:00:00Z" +LIGHTRAG_DIGEST = "sha256:lightrag" + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + return await openai_complete_if_cache( + "deepseek-chat", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("DEEPSEEK_API_KEY"), + base_url=os.getenv("DEEPSEEK_ENDPOINT"), + **kwargs, + ) + + +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": "http://m4.lan.znipower.com:11434", + "lollms": "http://localhost:9600", + "azure_openai": "https://api.openai.com/v1", + "openai": os.getenv("DEEPSEEK_ENDPOINT"), + } + return default_hosts.get( + binding_type, "http://localhost:11434" + ) # fallback to ollama if unknown + + +def parse_args(): + parser = argparse.ArgumentParser( + description="LightRAG FastAPI Server with separate working and input directories" + ) + + # Start by the bindings + parser.add_argument( + "--llm-binding", + default="ollama", + help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", + ) + parser.add_argument( + "--embedding-binding", + default="ollama", + help="Embedding binding to be used. 
Supported: lollms, ollama, openai (default: ollama)", + ) + + # Parse just these arguments first + temp_args, _ = parser.parse_known_args() + + # Add remaining arguments with dynamic defaults for hosts + # Server configuration + parser.add_argument( + "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=9621, help="Server port (default: 9621)" + ) + + # Directory configuration + parser.add_argument( + "--working-dir", + default="./rag_storage", + help="Working directory for RAG storage (default: ./rag_storage)", + ) + parser.add_argument( + "--input-dir", + default="./inputs", + help="Directory containing input documents (default: ./inputs)", + ) + + # LLM Model configuration + default_llm_host = get_default_host(temp_args.llm_binding) + parser.add_argument( + "--llm-binding-host", + default=default_llm_host, + help=f"llm server host URL (default: {default_llm_host})", + ) + + parser.add_argument( + "--llm-model", + default="mistral-nemo:latest", + help="LLM model name (default: mistral-nemo:latest)", + ) + + # Embedding model configuration + default_embedding_host = get_default_host(temp_args.embedding_binding) + parser.add_argument( + "--embedding-binding-host", + default=default_embedding_host, + help=f"embedding server host URL (default: {default_embedding_host})", + ) + + parser.add_argument( + "--embedding-model", + default="bge-m3:latest", + help="Embedding model name (default: bge-m3:latest)", + ) + + def timeout_type(value): + if value is None or value == "None": + return None + return int(value) + + parser.add_argument( + "--timeout", + default=None, + type=timeout_type, + help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", + ) + # RAG configuration + parser.add_argument( + "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" + ) + parser.add_argument( + "--max-tokens", + type=int, + default=32768, + help="Maximum token size (default: 32768)", + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=1024, + help="Embedding dimensions (default: 1024)", + ) + parser.add_argument( + "--max-embed-tokens", + type=int, + default=8192, + help="Maximum embedding token size (default: 8192)", + ) + + # Logging configuration + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level (default: INFO)", + ) + + parser.add_argument( + "--key", + type=str, + help="API key for authentication. 
This protects lightrag server against unauthorized access", + default=None, + ) + + # Optional https parameters + parser.add_argument( + "--ssl", action="store_true", help="Enable HTTPS (default: False)" + ) + parser.add_argument( + "--ssl-certfile", + default=None, + help="Path to SSL certificate file (required if --ssl is enabled)", + ) + parser.add_argument( + "--ssl-keyfile", + default=None, + help="Path to SSL private key file (required if --ssl is enabled)", + ) + return parser.parse_args() + + +class DocumentManager: + """Handles document operations and tracking""" + + def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): + self.input_dir = Path(input_dir) + self.supported_extensions = supported_extensions + self.indexed_files = set() + + # Create input directory if it doesn't exist + self.input_dir.mkdir(parents=True, exist_ok=True) + + def scan_directory(self) -> List[Path]: + """Scan input directory for new files""" + new_files = [] + for ext in self.supported_extensions: + for file_path in self.input_dir.rglob(f"*{ext}"): + if file_path not in self.indexed_files: + new_files.append(file_path) + return new_files + + def mark_as_indexed(self, file_path: Path): + """Mark a file as indexed""" + self.indexed_files.add(file_path) + + def is_supported_file(self, filename: str) -> bool: + """Check if file type is supported""" + return any(filename.lower().endswith(ext) for ext in self.supported_extensions) + + +# Pydantic models +class SearchMode(str, Enum): + naive = "naive" + local = "local" + global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global" + hybrid = "hybrid" + mix = "mix" + + +# Ollama API compatible models +class OllamaMessage(BaseModel): + role: str + content: str + images: Optional[List[str]] = None + + +class OllamaChatRequest(BaseModel): + model: str = LIGHTRAG_MODEL + messages: List[OllamaMessage] + stream: bool = True # Default to streaming mode + options: Optional[Dict[str, Any]] = None + + +class OllamaChatResponse(BaseModel): + model: str + created_at: str + message: OllamaMessage + done: bool + + +class OllamaVersionResponse(BaseModel): + version: str + + +class OllamaModelDetails(BaseModel): + parent_model: str + format: str + family: str + families: List[str] + parameter_size: str + quantization_level: str + + +class OllamaModel(BaseModel): + name: str + model: str + size: int + digest: str + modified_at: str + details: OllamaModelDetails + + +class OllamaTagResponse(BaseModel): + models: List[OllamaModel] + + +# Original LightRAG models +class QueryRequest(BaseModel): + query: str + mode: SearchMode = SearchMode.hybrid + stream: bool = False + only_need_context: bool = False + + +class QueryResponse(BaseModel): + response: str + + +class InsertTextRequest(BaseModel): + text: str + description: Optional[str] = None + + +class InsertResponse(BaseModel): + status: str + message: str + document_count: int + + +def get_api_key_dependency(api_key: Optional[str]): + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="API Key required" + ) + if 
api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + + +def create_app(args): + # Verify that bindings arer correctly setup + if args.llm_binding not in ["lollms", "ollama", "openai"]: + raise Exception("llm binding not supported") + + if args.embedding_binding not in ["lollms", "ollama", "openai"]: + raise Exception("embedding binding not supported") + + # Add SSL validation + if args.ssl: + if not args.ssl_certfile or not args.ssl_keyfile: + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) + if not os.path.exists(args.ssl_certfile): + raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") + if not os.path.exists(args.ssl_keyfile): + raise Exception(f"SSL key file not found: {args.ssl_keyfile}") + + # Setup logging + logging.basicConfig( + format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) + ) + + # Check if API key is provided either through env var or args + api_key = os.getenv("LIGHTRAG_API_KEY") or args.key + + # Initialize FastAPI + app = FastAPI( + title="LightRAG API", + description="API for querying text using LightRAG with separate storage and input directories" + + "(With authentication)" + if api_key + else "", + version="1.0.1", + openapi_tags=[{"name": "api"}], + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Create the optional API key dependency + optional_api_key = get_api_key_dependency(api_key) + + # Create working directory if it doesn't exist + Path(args.working_dir).mkdir(parents=True, exist_ok=True) + + # Initialize document manager + doc_manager = DocumentManager(args.input_dir) + + # Initialize RAG + rag = LightRAG( + working_dir=args.working_dir, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=1024, + max_token_size=8192, + func=lambda texts: ollama_embedding( + texts, + embed_model="bge-m3:latest", + host="http://m4.lan.znipower.com:11434", + ), + ), + ) + + @app.on_event("startup") + async def startup_event(): + """Index all files in input directory during startup""" + try: + new_files = doc_manager.scan_directory() + for file_path in new_files: + try: + # Use async file reading + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: + content = await f.read() + # Use the async version of insert directly + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + logging.info(f"Indexed file: {file_path}") + except Exception as e: + trace_exception(e) + logging.error(f"Error indexing file {file_path}: {str(e)}") + + logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") + + except Exception as e: + logging.error(f"Error during startup indexing: {str(e)}") + + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) + async def scan_for_new_documents(): + """Manually trigger scanning for new documents""" + try: + new_files = doc_manager.scan_directory() + indexed_count = 0 + + for file_path in new_files: + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + indexed_count += 1 + except Exception as e: + logging.error(f"Error indexing file {file_path}: {str(e)}") + + return { + "status": "success", + "indexed_count": indexed_count, + "total_documents": 
len(doc_manager.indexed_files), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) + async def upload_to_input_dir(file: UploadFile = File(...)): + """Upload a file to the input directory""" + try: + if not doc_manager.is_supported_file(file.filename): + raise HTTPException( + status_code=400, + detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + ) + + file_path = doc_manager.input_dir / file.filename + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + + # Immediately index the uploaded file + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + await rag.ainsert(content) + doc_manager.mark_as_indexed(file_path) + + return { + "status": "success", + "message": f"File uploaded and indexed: {file.filename}", + "total_documents": len(doc_manager.indexed_files), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) + async def query_text(request: QueryRequest): + try: + response = await rag.aquery( + request.query, + param=QueryParam( + mode=request.mode, + stream=request.stream, + only_need_context=request.only_need_context, + ), + ) + + # If response is a string (e.g. cache hit), return directly + if isinstance(response, str): + return QueryResponse(response=response) + + # If it's an async generator, decide whether to stream based on stream parameter + if request.stream: + result = "" + async for chunk in response: + result += chunk + return QueryResponse(response=result) + else: + result = "" + async for chunk in response: + result += chunk + return QueryResponse(response=result) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) + async def query_text_stream(request: QueryRequest): + try: + response = await rag.aquery( # Use aquery instead of query, and add await + request.query, + param=QueryParam( + mode=request.mode, + stream=True, + only_need_context=request.only_need_context, + ), + ) + + from fastapi.responses import StreamingResponse + + async def stream_generator(): + if isinstance(response, str): + # If it's a string, send it all at once + yield f"{json.dumps({'response': response})}\n" + else: + # If it's an async generator, send chunks one by one + try: + async for chunk in response: + if chunk: # Only send non-empty content + yield f"{json.dumps({'response': chunk})}\n" + except Exception as e: + logging.error(f"Streaming error: {str(e)}") + yield f"{json.dumps({'error': str(e)})}\n" + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + "X-Accel-Buffering": "no", # Disable Nginx buffering + }, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_text(request: InsertTextRequest): + try: + await rag.ainsert(request.text) + return InsertResponse( + status="success", + message="Text successfully inserted", + 
document_count=1, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_file(file: UploadFile = File(...), description: str = Form(None)): + try: + content = await file.read() + + if file.filename.endswith((".txt", ".md")): + text = content.decode("utf-8") + await rag.ainsert(text) + else: + raise HTTPException( + status_code=400, + detail="Unsupported file type. Only .txt and .md files are supported", + ) + + return InsertResponse( + status="success", + message=f"File '{file.filename}' successfully inserted", + document_count=1, + ) + except UnicodeDecodeError: + raise HTTPException(status_code=400, detail="File encoding not supported") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def insert_batch(files: List[UploadFile] = File(...)): + try: + inserted_count = 0 + failed_files = [] + + for file in files: + try: + content = await file.read() + if file.filename.endswith((".txt", ".md")): + text = content.decode("utf-8") + await rag.ainsert(text) + inserted_count += 1 + else: + failed_files.append(f"{file.filename} (unsupported type)") + except Exception as e: + failed_files.append(f"{file.filename} ({str(e)})") + + status_message = f"Successfully inserted {inserted_count} documents" + if failed_files: + status_message += f". Failed files: {', '.join(failed_files)}" + + return InsertResponse( + status="success" if inserted_count > 0 else "partial_success", + message=status_message, + document_count=len(files), + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) + async def clear_documents(): + try: + rag.text_chunks = [] + rag.entities_vdb = None + rag.relationships_vdb = None + return InsertResponse( + status="success", + message="All documents cleared successfully", + document_count=0, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + # Ollama compatible API endpoints + @app.get("/api/version") + async def get_version(): + """Get Ollama version information""" + return OllamaVersionResponse(version="0.5.4") + + @app.get("/api/tags") + async def get_tags(): + """Get available models""" + return OllamaTagResponse( + models=[ + { + "name": LIGHTRAG_MODEL, + "model": LIGHTRAG_MODEL, + "size": LIGHTRAG_SIZE, + "digest": LIGHTRAG_DIGEST, + "modified_at": LIGHTRAG_CREATED_AT, + "details": { + "parent_model": "", + "format": "gguf", + "family": LIGHTRAG_NAME, + "families": [LIGHTRAG_NAME], + "parameter_size": "13B", + "quantization_level": "Q4_0", + }, + } + ] + ) + + def parse_query_mode(query: str) -> tuple[str, SearchMode]: + """Parse query prefix to determine search mode + Returns tuple of (cleaned_query, search_mode) + """ + mode_map = { + "/local ": SearchMode.local, + "/global ": SearchMode.global_, # global_ is used because 'global' is a Python keyword + "/naive ": SearchMode.naive, + "/hybrid ": SearchMode.hybrid, + "/mix ": SearchMode.mix, + } + + for prefix, mode in mode_map.items(): + if query.startswith(prefix): + # After removing prefix an leading spaces + cleaned_query = query[len(prefix) :].lstrip() + return cleaned_query, mode + + return query, SearchMode.hybrid + + 
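+    # Examples (illustrative): parse_query_mode maps a prefixed query to (cleaned_query, mode),
+    # falling back to hybrid when no prefix is present:
+    #   parse_query_mode("/local who is the author?")  -> ("who is the author?", SearchMode.local)
+    #   parse_query_mode("/mix summarize chapter 1")    -> ("summarize chapter 1", SearchMode.mix)
+    #   parse_query_mode("plain question")              -> ("plain question", SearchMode.hybrid)
+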
@app.post("/api/chat") + async def chat(raw_request: Request, request: OllamaChatRequest): + """Handle chat completion requests""" + try: + # Get all messages + messages = request.messages + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") + + # Get the last message as query + query = messages[-1].content + + # 解析查询模式 + cleaned_query, mode = parse_query_mode(query) + + # 开始计时 + start_time = time.time_ns() + + # 计算输入token数量 + prompt_tokens = estimate_tokens(cleaned_query) + + # 调用RAG进行查询 + query_param = QueryParam( + mode=mode, stream=request.stream, only_need_context=False + ) + + if request.stream: + from fastapi.responses import StreamingResponse + + response = await rag.aquery( # Need await to get async generator + cleaned_query, param=query_param + ) + + async def stream_generator(): + try: + first_chunk_time = None + last_chunk_time = None + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = time.time_ns() + last_chunk_time = first_chunk_time + total_response = response + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": response, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + + total_response += chunk + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": chunk, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + return # Ensure the generator ends immediately after sending the completion marker + except Exception as e: + logging.error(f"Error in stream_generator: {str(e)}") + raise + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }, + ) + else: + first_chunk_time = time.time_ns() + response_text = await rag.aquery(cleaned_query, param=query_param) + last_chunk_time = time.time_ns() 
+ + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": LIGHTRAG_MODEL, + "created_at": LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": str(response_text), + "images": None, + }, + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.get("/health", dependencies=[Depends(optional_api_key)]) + async def get_status(): + """Get current system status""" + return { + "status": "healthy", + "working_directory": str(args.working_dir), + "input_directory": str(args.input_dir), + "indexed_files": len(doc_manager.indexed_files), + "configuration": { + # LLM configuration binding/host address (if applicable)/model (if applicable) + "llm_binding": args.llm_binding, + "llm_binding_host": args.llm_binding_host, + "llm_model": args.llm_model, + # embedding model configuration binding/host address (if applicable)/model (if applicable) + "embedding_binding": args.embedding_binding, + "embedding_binding_host": args.embedding_binding_host, + "embedding_model": args.embedding_model, + "max_tokens": args.max_tokens, + }, + } + + return app + + +def main(): + args = parse_args() + import uvicorn + + app = create_app(args) + uvicorn_config = { + "app": app, + "host": args.host, + "port": args.port, + } + if args.ssl: + uvicorn_config.update( + { + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + } + ) + uvicorn.run(**uvicorn_config) + + +if __name__ == "__main__": + main() diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt index 9154809c..74776828 100644 --- a/lightrag/api/requirements.txt +++ b/lightrag/api/requirements.txt @@ -1,7 +1,6 @@ aioboto3 ascii_colors fastapi -lightrag-hku nano_vectordb nest_asyncio numpy diff --git a/setup.py b/setup.py index 38eff646..b5850d26 100644 --- a/setup.py +++ b/setup.py @@ -101,6 +101,7 @@ setuptools.setup( entry_points={ "console_scripts": [ "lightrag-server=lightrag.api.lightrag_server:main [api]", + "lightrag-ollama=lightrag.api.lightrag_ollama:main [api]", ], }, ) diff --git a/start-server.sh b/start-server.sh new file mode 100755 index 00000000..4e143f37 --- /dev/null +++ b/start-server.sh @@ -0,0 +1,3 @@ +. venv/bin/activate + +lightrag-ollama --llm-binding openai --llm-model deepseek-chat --embedding-model "bge-m3:latest" --embedding-dim 1024 diff --git a/test_lightrag_ollama_chat.py b/test_lightrag_ollama_chat.py new file mode 100644 index 00000000..96aee692 --- /dev/null +++ b/test_lightrag_ollama_chat.py @@ -0,0 +1,572 @@ +""" +LightRAG Ollama Compatibility Interface Test Script + +This script tests the LightRAG's Ollama compatibility interface, including: +1. Basic functionality tests (streaming and non-streaming responses) +2. Query mode tests (local, global, naive, hybrid) +3. Error handling tests (including streaming and non-streaming scenarios) + +All responses use the JSON Lines format, complying with the Ollama API specification. 
+""" + +import requests +import json +import argparse +import time +from typing import Dict, Any, Optional, List, Callable +from dataclasses import dataclass, asdict +from datetime import datetime +from pathlib import Path + + +class OutputControl: + """Output control class, manages the verbosity of test output""" + + _verbose: bool = False + + @classmethod + def set_verbose(cls, verbose: bool) -> None: + cls._verbose = verbose + + @classmethod + def is_verbose(cls) -> bool: + return cls._verbose + + +@dataclass +class TestResult: + """Test result data class""" + + name: str + success: bool + duration: float + error: Optional[str] = None + timestamp: str = "" + + def __post_init__(self): + if not self.timestamp: + self.timestamp = datetime.now().isoformat() + + +class TestStats: + """Test statistics""" + + def __init__(self): + self.results: List[TestResult] = [] + self.start_time = datetime.now() + + def add_result(self, result: TestResult): + self.results.append(result) + + def export_results(self, path: str = "test_results.json"): + """Export test results to a JSON file + Args: + path: Output file path + """ + results_data = { + "start_time": self.start_time.isoformat(), + "end_time": datetime.now().isoformat(), + "results": [asdict(r) for r in self.results], + "summary": { + "total": len(self.results), + "passed": sum(1 for r in self.results if r.success), + "failed": sum(1 for r in self.results if not r.success), + "total_duration": sum(r.duration for r in self.results), + }, + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(results_data, f, ensure_ascii=False, indent=2) + print(f"\nTest results saved to: {path}") + + def print_summary(self): + total = len(self.results) + passed = sum(1 for r in self.results if r.success) + failed = total - passed + duration = sum(r.duration for r in self.results) + + print("\n=== Test Summary ===") + print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"Total duration: {duration:.2f} seconds") + print(f"Total tests: {total}") + print(f"Passed: {passed}") + print(f"Failed: {failed}") + + if failed > 0: + print("\nFailed tests:") + for result in self.results: + if not result.success: + print(f"- {result.name}: {result.error}") + + +DEFAULT_CONFIG = { + "server": { + "host": "localhost", + "port": 9621, + "model": "lightrag:latest", + "timeout": 30, + "max_retries": 3, + "retry_delay": 1, + }, + "test_cases": {"basic": {"query": "唐僧有几个徒弟"}}, +} + + +def make_request( + url: str, data: Dict[str, Any], stream: bool = False +) -> requests.Response: + """Send an HTTP request with retry mechanism + Args: + url: Request URL + data: Request data + stream: Whether to use streaming response + Returns: + requests.Response: Response object + + Raises: + requests.exceptions.RequestException: Request failed after all retries + """ + server_config = CONFIG["server"] + max_retries = server_config["max_retries"] + retry_delay = server_config["retry_delay"] + timeout = server_config["timeout"] + + for attempt in range(max_retries): + try: + response = requests.post(url, json=data, stream=stream, timeout=timeout) + return response + except requests.exceptions.RequestException as e: + if attempt == max_retries - 1: # Last retry + raise + print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}") + time.sleep(retry_delay) + + +def load_config() -> Dict[str, Any]: + """Load configuration file + + First try to load from config.json in the current directory, + if it doesn't exist, use the default configuration + Returns: + 
Configuration dictionary + """ + config_path = Path("config.json") + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + return DEFAULT_CONFIG + + +def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None: + """Format and print JSON response data + Args: + data: Data dictionary to print + title: Title to print + indent: Number of spaces for JSON indentation + """ + if OutputControl.is_verbose(): + if title: + print(f"\n=== {title} ===") + print(json.dumps(data, ensure_ascii=False, indent=indent)) + + +# Global configuration +CONFIG = load_config() + + +def get_base_url() -> str: + """Return the base URL""" + server = CONFIG["server"] + return f"http://{server['host']}:{server['port']}/api/chat" + + +def create_request_data( + content: str, stream: bool = False, model: str = None +) -> Dict[str, Any]: + """Create basic request data + Args: + content: User message content + stream: Whether to use streaming response + model: Model name + Returns: + Dictionary containing complete request data + """ + return { + "model": model or CONFIG["server"]["model"], + "messages": [{"role": "user", "content": content}], + "stream": stream, + } + + +# Global test statistics +STATS = TestStats() + + +def run_test(func: Callable, name: str) -> None: + """Run a test and record the results + Args: + func: Test function + name: Test name + """ + start_time = time.time() + try: + func() + duration = time.time() - start_time + STATS.add_result(TestResult(name, True, duration)) + except Exception as e: + duration = time.time() - start_time + STATS.add_result(TestResult(name, False, duration, str(e))) + raise + + +def test_non_stream_chat(): + """Test non-streaming call to /api/chat endpoint""" + url = get_base_url() + data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False) + + # Send request + response = make_request(url, data) + + # Print response + if OutputControl.is_verbose(): + print("\n=== Non-streaming call response ===") + response_json = response.json() + + # Print response content + print_json_response( + {"model": response_json["model"], "message": response_json["message"]}, + "Response content", + ) + + +def test_stream_chat(): + """Test streaming call to /api/chat endpoint + + Use JSON Lines format to process streaming responses, each line is a complete JSON object. + Response format: + { + "model": "lightrag:latest", + "created_at": "2024-01-15T00:00:00Z", + "message": { + "role": "assistant", + "content": "Partial response content", + "images": null + }, + "done": false + } + + The last message will contain performance statistics, with done set to true. 
+ """ + url = get_base_url() + data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True) + + # Send request and get streaming response + response = make_request(url, data, stream=True) + + if OutputControl.is_verbose(): + print("\n=== Streaming call response ===") + output_buffer = [] + try: + for line in response.iter_lines(): + if line: # Skip empty lines + try: + # Decode and parse JSON + data = json.loads(line.decode("utf-8")) + if data.get("done", True): # If it's the completion marker + if ( + "total_duration" in data + ): # Final performance statistics message + # print_json_response(data, "Performance statistics") + break + else: # Normal content message + message = data.get("message", {}) + content = message.get("content", "") + if content: # Only collect non-empty content + output_buffer.append(content) + print( + content, end="", flush=True + ) # Print content in real-time + except json.JSONDecodeError: + print("Error decoding JSON from response line") + finally: + response.close() # Ensure the response connection is closed + + # Print a newline + print() + + +def test_query_modes(): + """Test different query mode prefixes + + Supported query modes: + - /local: Local retrieval mode, searches only in highly relevant documents + - /global: Global retrieval mode, searches across all documents + - /naive: Naive mode, does not use any optimization strategies + - /hybrid: Hybrid mode (default), combines multiple strategies + - /mix: Mix mode + + Each mode will return responses in the same format, but with different retrieval strategies. + """ + url = get_base_url() + modes = ["local", "global", "naive", "hybrid", "mix"] + + for mode in modes: + if OutputControl.is_verbose(): + print(f"\n=== Testing /{mode} mode ===") + data = create_request_data( + f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False + ) + + # Send request + response = make_request(url, data) + response_json = response.json() + + # Print response content + print_json_response( + {"model": response_json["model"], "message": response_json["message"]} + ) + + +def create_error_test_data(error_type: str) -> Dict[str, Any]: + """Create request data for error testing + Args: + error_type: Error type, supported: + - empty_messages: Empty message list + - invalid_role: Invalid role field + - missing_content: Missing content field + + Returns: + Request dictionary containing error data + """ + error_data = { + "empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True}, + "invalid_role": { + "model": "lightrag:latest", + "messages": [{"invalid_role": "user", "content": "Test message"}], + "stream": True, + }, + "missing_content": { + "model": "lightrag:latest", + "messages": [{"role": "user"}], + "stream": True, + }, + } + return error_data.get(error_type, error_data["empty_messages"]) + + +def test_stream_error_handling(): + """Test error handling for streaming responses + + Test scenarios: + 1. Empty message list + 2. Message format error (missing required fields) + + Error responses should be returned immediately without establishing a streaming connection. + The status code should be 4xx, and detailed error information should be returned. 
+ """ + url = get_base_url() + + if OutputControl.is_verbose(): + print("\n=== Testing streaming response error handling ===") + + # Test empty message list + if OutputControl.is_verbose(): + print("\n--- Testing empty message list (streaming) ---") + data = create_error_test_data("empty_messages") + response = make_request(url, data, stream=True) + print(f"Status code: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "Error message") + response.close() + + # Test invalid role field + if OutputControl.is_verbose(): + print("\n--- Testing invalid role field (streaming) ---") + data = create_error_test_data("invalid_role") + response = make_request(url, data, stream=True) + print(f"Status code: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "Error message") + response.close() + + # Test missing content field + if OutputControl.is_verbose(): + print("\n--- Testing missing content field (streaming) ---") + data = create_error_test_data("missing_content") + response = make_request(url, data, stream=True) + print(f"Status code: {response.status_code}") + if response.status_code != 200: + print_json_response(response.json(), "Error message") + response.close() + + +def test_error_handling(): + """Test error handling for non-streaming responses + + Test scenarios: + 1. Empty message list + 2. Message format error (missing required fields) + + Error response format: + { + "detail": "Error description" + } + + All errors should return appropriate HTTP status codes and clear error messages. + """ + url = get_base_url() + + if OutputControl.is_verbose(): + print("\n=== Testing error handling ===") + + # Test empty message list + if OutputControl.is_verbose(): + print("\n--- Testing empty message list ---") + data = create_error_test_data("empty_messages") + data["stream"] = False # Change to non-streaming mode + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + # Test invalid role field + if OutputControl.is_verbose(): + print("\n--- Testing invalid role field ---") + data = create_error_test_data("invalid_role") + data["stream"] = False # Change to non-streaming mode + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + # Test missing content field + if OutputControl.is_verbose(): + print("\n--- Testing missing content field ---") + data = create_error_test_data("missing_content") + data["stream"] = False # Change to non-streaming mode + response = make_request(url, data) + print(f"Status code: {response.status_code}") + print_json_response(response.json(), "Error message") + + +def get_test_cases() -> Dict[str, Callable]: + """Get all available test cases + Returns: + A dictionary mapping test names to test functions + """ + return { + "non_stream": test_non_stream_chat, + "stream": test_stream_chat, + "modes": test_query_modes, + "errors": test_error_handling, + "stream_errors": test_stream_error_handling, + } + + +def create_default_config(): + """Create a default configuration file""" + config_path = Path("config.json") + if not config_path.exists(): + with open(config_path, "w", encoding="utf-8") as f: + json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2) + print(f"Default configuration file created: {config_path}") + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments""" + parser = 
argparse.ArgumentParser( + description="LightRAG Ollama Compatibility Interface Testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Configuration file (config.json): + { + "server": { + "host": "localhost", # Server address + "port": 9621, # Server port + "model": "lightrag:latest" # Default model name + }, + "test_cases": { + "basic": { + "query": "Test query", # Basic query text + "stream_query": "Stream query" # Stream query text + } + } + } +""", + ) + parser.add_argument( + "-q", + "--quiet", + action="store_true", + help="Silent mode, only display test result summary", + ) + parser.add_argument( + "-a", + "--ask", + type=str, + help="Specify query content, which will override the query settings in the configuration file", + ) + parser.add_argument( + "--init-config", action="store_true", help="Create default configuration file" + ) + parser.add_argument( + "--output", + type=str, + default="", + help="Test result output file path, default is not to output to a file", + ) + parser.add_argument( + "--tests", + nargs="+", + choices=list(get_test_cases().keys()) + ["all"], + default=["all"], + help="Test cases to run, options: %(choices)s. Use 'all' to run all tests", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + # Set output mode + OutputControl.set_verbose(not args.quiet) + + # If query content is specified, update the configuration + if args.ask: + CONFIG["test_cases"]["basic"]["query"] = args.ask + + # If specified to create a configuration file + if args.init_config: + create_default_config() + exit(0) + + test_cases = get_test_cases() + + try: + if "all" in args.tests: + # Run all tests + if OutputControl.is_verbose(): + print("\n【Basic Functionality Tests】") + run_test(test_non_stream_chat, "Non-streaming Call Test") + run_test(test_stream_chat, "Streaming Call Test") + + if OutputControl.is_verbose(): + print("\n【Query Mode Tests】") + run_test(test_query_modes, "Query Mode Test") + + if OutputControl.is_verbose(): + print("\n【Error Handling Tests】") + run_test(test_error_handling, "Error Handling Test") + run_test(test_stream_error_handling, "Streaming Error Handling Test") + else: + # Run specified tests + for test_name in args.tests: + if OutputControl.is_verbose(): + print(f"\n【Running Test: {test_name}】") + run_test(test_cases[test_name], test_name) + except Exception as e: + print(f"\nAn error occurred: {str(e)}") + finally: + # Print test statistics + STATS.print_summary() + # If an output file path is specified, export the results + if args.output: + STATS.export_results(args.output)
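+
+# Example invocations (illustrative), using the CLI flags defined in parse_args() above:
+#   python test_lightrag_ollama_chat.py                            # run all tests verbosely
+#   python test_lightrag_ollama_chat.py -q --tests stream modes    # run selected tests, summary only
+#   python test_lightrag_ollama_chat.py -a "your question" --output results.json
+#   python test_lightrag_ollama_chat.py --init-config              # write a default config.json and exit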