From 2297007b7b240e17fa77e4dc5aad228a8b0a1b65 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Fri, 10 Jan 2025 20:30:58 +0100
Subject: [PATCH 1/7] Simplified the API services (issue #565)

---
 README.md                                     | 133 ++---
 lightrag/api/azure_openai_lightrag_server.py  | 532 ------
 ..._lightrag_server.py => lightrag_server.py} |  94 +++-
 lightrag/api/ollama_lightrag_server.py        | 491 ----
 lightrag/api/openai_lightrag_server.py        | 506 -----
 setup.py                                      |   5 +-
 6 files changed, 136 insertions(+), 1625 deletions(-)
 delete mode 100644 lightrag/api/azure_openai_lightrag_server.py
 rename lightrag/api/{lollms_lightrag_server.py => lightrag_server.py} (82%)
 delete mode 100644 lightrag/api/ollama_lightrag_server.py
 delete mode 100644 lightrag/api/openai_lightrag_server.py

diff --git a/README.md b/README.md
index 6c981d92..278f6a72 100644
--- a/README.md
+++ b/README.md
@@ -912,12 +912,14 @@ pip install -e ".[api]"
 
 ### Prerequisites
 
-Before running any of the servers, ensure you have the corresponding backend service running:
+Before running any of the servers, ensure the corresponding backend services are running for both the LLM and the embeddings.
+The new API allows you to mix different bindings for the LLM and the embeddings.
+For example, you can use ollama for the embeddings and openai for the LLM.
 
 #### For LoLLMs Server
 - LoLLMs must be running and accessible
 - Default connection: http://localhost:9600
-- Configure using --lollms-host if running on a different host/port
+- Configure using --llm-binding-host and/or --embedding-binding-host if running on a different host/port
 
 #### For Ollama Server
 - Ollama must be running and accessible
@@ -953,15 +955,19 @@ The output of the last command will give you the endpoint and the key for the Op
 
 Each server has its own specific configuration options:
 
-#### LoLLMs Server Options
+#### LightRAG Server Options
 
 | Parameter | Default | Description |
 |-----------|---------|-------------|
 | --host | 0.0.0.0 | RAG server host |
 | --port | 9621 | RAG server port |
+| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai (default: ollama) |
+| --llm-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | LLM server host URL (the default depends on the selected binding) |
 | --model | mistral-nemo:latest | LLM model name |
+| --embedding-binding | ollama | Embedding binding to be used.
Supported: lollms, ollama, openai (default: ollama) | +| --embedding-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | embedding server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | | --embedding-model | bge-m3:latest | Embedding model name | -| --lollms-host | http://localhost:9600 | LoLLMS backend URL | +| --embedding-binding-host | http://localhost:9600 | LoLLMS backend URL | | --working-dir | ./rag_storage | Working directory for RAG | | --max-async | 4 | Maximum async operations | | --max-tokens | 32768 | Maximum token size | @@ -971,95 +977,71 @@ Each server has its own specific configuration options: | --log-level | INFO | Logging level | | --key | none | Access Key to protect the lightrag service | -#### Ollama Server Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | RAG server host | -| --port | 9621 | RAG server port | -| --model | mistral-nemo:latest | LLM model name | -| --embedding-model | bge-m3:latest | Embedding model name | -| --ollama-host | http://localhost:11434 | Ollama backend URL | -| --working-dir | ./rag_storage | Working directory for RAG | -| --max-async | 4 | Maximum async operations | -| --max-tokens | 32768 | Maximum token size | -| --embedding-dim | 1024 | Embedding dimensions | -| --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-file | ./book.txt | Initial input file | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | - -#### OpenAI Server Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | RAG server host | -| --port | 9621 | RAG server port | -| --model | gpt-4 | OpenAI model name | -| --embedding-model | text-embedding-3-large | OpenAI embedding model | -| --working-dir | ./rag_storage | Working directory for RAG | -| --max-tokens | 32768 | Maximum token size | -| --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-dir | ./inputs | Input directory for documents | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | - -#### OpenAI AZURE Server Options - -| Parameter | Default | Description | -|-----------|---------|-------------| -| --host | 0.0.0.0 | Server host | -| --port | 9621 | Server port | -| --model | gpt-4 | OpenAI model name | -| --embedding-model | text-embedding-3-large | OpenAI embedding model | -| --working-dir | ./rag_storage | Working directory for RAG | -| --max-tokens | 32768 | Maximum token size | -| --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-dir | ./inputs | Input directory for documents | -| --enable-cache | True | Enable response cache | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`. ### Example Usage -#### LoLLMs RAG Server +#### Running a Lightrag server with ollama default local server as llm and embedding backends + +Ollama is the default backend for both llm and embedding, so by default you can run lightrag-server with no parameters and the default ones will be used. 
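+For example, assuming a stock local Ollama install, the two default models can be pulled ahead of time (an illustrative sketch — substitute whichever models you plan to use):
+
+```bash
+# Pull the default LLM and embedding models used by lightrag-server
+ollama pull mistral-nemo:latest
+ollama pull bge-m3:latest
+```
+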
+Make sure Ollama is installed and running, and that the models you plan to use are already available, before starting the server.
 
 ```bash
-# Custom configuration with specific model and working directory
-lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_rag
+# Run lightrag with ollama, using mistral-nemo:latest for the llm and bge-m3:latest for the embedding
+lightrag-server
 
-# Using specific models (ensure they are installed in your LoLLMs instance)
-lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
+# Using specific models (ensure they are installed in your ollama instance)
+lightrag-server --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-model nomic-embed-text --embedding-dim 1024
 
-# Using specific models and an authentication key
-lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey
+# Using an authentication key
+lightrag-server --key my-key
 
+# Using lollms for the llm and ollama for the embedding
+lightrag-server --llm-binding lollms
 ```
 
-#### Ollama RAG Server
+#### Running a LightRAG server with lollms as the default local server for llm and embedding
 
 ```bash
-# Custom configuration with specific model and working directory
-ollama-lightrag-server --model mistral-nemo:latest --port 8080 --working-dir ./custom_rag
+# Run lightrag with lollms for both the llm and the embedding, using mistral-nemo:latest and bge-m3:latest
+lightrag-server --llm-binding lollms --embedding-binding lollms
 
-# Using specific models (ensure they are installed in your Ollama instance)
-ollama-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024
+# Using specific models (ensure they are installed in your lollms instance)
+lightrag-server --llm-binding lollms --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-binding lollms --embedding-model nomic-embed-text --embedding-dim 1024
+
+# Using an authentication key
+lightrag-server --llm-binding lollms --embedding-binding lollms --key my-key
+
+# Using lollms for the llm and openai for the embedding
+lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
 ```
 
-#### OpenAI RAG Server
+
+#### Running a LightRAG server with openai as the llm and embedding backend
 ```bash
-# Using GPT-4 with text-embedding-3-large
-openai-lightrag-server --port 9624 --model gpt-4 --embedding-model text-embedding-3-large
-```
-#### Azure OpenAI RAG Server
-```bash
-# Using GPT-4 with text-embedding-3-large
-azure-openai-lightrag-server --model gpt-4o --port 8080 --working-dir ./custom_rag --embedding-model text-embedding-3-large
+
+# Run lightrag with openai for both the llm and the embedding, using gpt-4o-mini and text-embedding-3-small
+lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small
+
+# Using an authentication key
+lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small --key my-key
+
+# Using lollms for the llm and openai for the embedding
+lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
 ```

+#### Running a LightRAG server with azure openai as the llm and embedding backend
+
+```bash
+# Run lightrag with azure_openai for the llm (gpt-4o-mini) and openai for the embedding (text-embedding-3-small)
+lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini \
--embedding-binding openai --embedding-model text-embedding-3-small + +# Using an authentication key +lightrag-server --llm-binding azure_openai --llm-model GPT-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small --key my-key + +# Using lollms for llm and azure_openai for embedding +lightrag-server --llm-binding lollms --embedding-binding azure_openai --embedding-model text-embedding-3-small +``` **Important Notes:** - For LoLLMs: Make sure the specified models are installed in your LoLLMs instance @@ -1069,10 +1051,7 @@ azure-openai-lightrag-server --model gpt-4o --port 8080 --working-dir ./custom_r For help on any server, use the --help flag: ```bash -lollms-lightrag-server --help -ollama-lightrag-server --help -openai-lightrag-server --help -azure-openai-lightrag-server --help +lightrag-server --help ``` Note: If you don't need the API functionality, you can install the base package without API support using: @@ -1092,7 +1071,7 @@ Query the RAG system with options for different search modes. ```bash curl -X POST "http://localhost:9621/query" \ -H "Content-Type: application/json" \ - -d '{"query": "Your question here", "mode": "hybrid"}' + -d '{"query": "Your question here", "mode": "hybrid", ""}' ``` #### POST /query/stream diff --git a/lightrag/api/azure_openai_lightrag_server.py b/lightrag/api/azure_openai_lightrag_server.py deleted file mode 100644 index abe3f738..00000000 --- a/lightrag/api/azure_openai_lightrag_server.py +++ /dev/null @@ -1,532 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form -from pydantic import BaseModel -import asyncio -import logging -import argparse -from lightrag import LightRAG, QueryParam -from lightrag.llm import ( - azure_openai_complete_if_cache, - azure_openai_embedding, -) -from lightrag.utils import EmbeddingFunc -from typing import Optional, List -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import os -from dotenv import load_dotenv -import inspect -import json -from fastapi.responses import StreamingResponse - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - -load_dotenv() - -AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") -AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") -AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") -AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") - -AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") -AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with OpenAI integration" - ) - - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # Model configuration - parser.add_argument( - "--model", default="gpt-4o", help="OpenAI model name (default: gpt-4o)" - ) - parser.add_argument( - "--embedding-model", - 
default="text-embedding-3-large", - help="OpenAI embedding model (default: text-embedding-3-large)", - ) - - # RAG configuration - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - parser.add_argument( - "--enable-cache", - default=True, - help="Enable response cache (default: True)", - ) - # Logging configuration - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - - -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - only_need_context: bool = False - # stream: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -async def get_embedding_dim(embedding_model: str) -> int: - """Get embedding dimensions for the specified model""" - test_text = ["This is a test sentence."] - embedding = await azure_openai_embedding(test_text, model=embedding_model) - return embedding.shape[1] - - -def create_app(args): - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - 
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.0", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # Get embedding dimensions - embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model)) - - async def async_openai_complete( - prompt, system_prompt=None, history_messages=[], **kwargs - ): - """Async wrapper for OpenAI completion""" - kwargs.pop("keyword_extraction", None) - - return await azure_openai_complete_if_cache( - args.model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - base_url=AZURE_OPENAI_ENDPOINT, - api_key=AZURE_OPENAI_API_KEY, - api_version=AZURE_OPENAI_API_VERSION, - **kwargs, - ) - - # Initialize RAG with OpenAI configuration - rag = LightRAG( - enable_llm_cache=args.enable_cache, - working_dir=args.working_dir, - llm_model_func=async_openai_complete, - llm_model_name=args.model, - llm_model_max_token_size=args.max_tokens, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: azure_openai_embedding( - texts, model=args.embedding_model - ), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/resetcache", dependencies=[Depends(optional_api_key)]) - async def reset_cache(): - """Manually reset cache""" - try: - cachefile = args.working_dir + "/kv_store_llm_response_cache.json" - if os.path.exists(cachefile): - with open(cachefile, "w") as f: - f.write("{}") - 
return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=False, - only_need_context=request.only_need_context, - ), - ) - return QueryResponse(response=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - if inspect.isasyncgen(response): - - async def stream_generator(): - async for chunk in response: - yield json.dumps({"data": chunk}) + "\n" - - return StreamingResponse( - stream_generator(), media_type="application/json" - ) - else: - return QueryResponse(response=response) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=1, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. 
Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - "model": args.model, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - "embedding_dim": embedding_dim, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/lightrag/api/lollms_lightrag_server.py b/lightrag/api/lightrag_server.py similarity index 82% rename from lightrag/api/lollms_lightrag_server.py rename to lightrag/api/lightrag_server.py index 8a2804a0..4f8e38cd 100644 --- a/lightrag/api/lollms_lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -4,6 +4,10 @@ import logging import argparse from lightrag import LightRAG, QueryParam from lightrag.llm import lollms_model_complete, lollms_embed +from lightrag.llm import ollama_model_complete, ollama_embed +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding + from lightrag.utils import EmbeddingFunc from typing import Optional, List from enum import Enum @@ -19,12 +23,36 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN +def get_default_host(binding_type: str) -> str: + default_hosts = { + "ollama": "http://localhost:11434", + "lollms": "http://localhost:9600", + "azure_openai": "https://api.openai.com/v1", + "openai": "https://api.openai.com/v1" + } + return default_hosts.get(binding_type, 
"http://localhost:11434") # fallback to ollama if unknown def parse_args(): parser = argparse.ArgumentParser( description="LightRAG FastAPI Server with separate working and input directories" ) + #Start by the bindings + parser.add_argument( + "--llm-binding", + default="ollama", + help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", + ) + parser.add_argument( + "--embedding-binding", + default="ollama", + help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", + ) + + # Parse just these arguments first + temp_args, _ = parser.parse_known_args() + + # Add remaining arguments with dynamic defaults for hosts # Server configuration parser.add_argument( "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" @@ -45,22 +73,33 @@ def parse_args(): help="Directory containing input documents (default: ./inputs)", ) - # Model configuration + # LLM Model configuration + default_llm_host = get_default_host(temp_args.llm_binding) parser.add_argument( - "--model", + "--llm-binding-host", + default=default_llm_host, + help=f"llm server host URL (default: {default_llm_host})", + ) + + parser.add_argument( + "--llm-model", default="mistral-nemo:latest", help="LLM model name (default: mistral-nemo:latest)", ) + + # Embedding model configuration + default_embedding_host = get_default_host(temp_args.embedding_binding) + parser.add_argument( + "--embedding-binding-host", + default=default_embedding_host, + help=f"embedding server host URL (default: {default_embedding_host})", + ) + parser.add_argument( "--embedding-model", default="bge-m3:latest", help="Embedding model name (default: bge-m3:latest)", ) - parser.add_argument( - "--lollms-host", - default="http://localhost:9600", - help="lollms host URL (default: http://localhost:9600)", - ) # RAG configuration parser.add_argument( @@ -188,6 +227,15 @@ def get_api_key_dependency(api_key: Optional[str]): def create_app(args): + # Verify that bindings arer correctly setup + if args.llm_binding not in ["lollms", "ollama", "openai"]: + raise Exception("llm binding not supported") + + if args.embedding_binding not in ["lollms", "ollama", "openai"]: + raise Exception("embedding binding not supported") + + + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -203,7 +251,7 @@ def create_app(args): + "(With authentication)" if api_key else "", - version="1.0.0", + version="1.0.1", openapi_tags=[{"name": "api"}], ) @@ -225,23 +273,32 @@ def create_app(args): # Initialize document manager doc_manager = DocumentManager(args.input_dir) + + # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=lollms_model_complete, - llm_model_name=args.model, + llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache, + llm_model_name=args.llm_model, llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ - "host": args.lollms_host, + "host": args.llm_binding_host, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( embedding_dim=args.embedding_dim, max_token_size=args.max_embed_tokens, func=lambda texts: lollms_embed( - texts, embed_model=args.embedding_model, host=args.lollms_host - ), + texts, embed_model=args.embedding_model, host=args.embedding_binding_host + ) if args.llm_binding=="lollms" 
else ollama_embed( + texts, embed_model=args.embedding_model, host=args.embedding_binding_host + ) if args.llm_binding=="ollama" else azure_openai_embedding( + texts, model=args.embedding_model # no host is used for openai + ) if args.llm_binding=="azure_openai" else openai_embedding( + texts, model=args.embedding_model # no host is used for openai + ) + ), ) @@ -470,10 +527,17 @@ def create_app(args): "input_directory": str(args.input_dir), "indexed_files": len(doc_manager.indexed_files), "configuration": { - "model": args.model, + # LLM configuration binding/host address (if applicable)/model (if applicable) + "llm_binding": args.llm_binding, + "llm_binding_host": args.llm_binding_host, + "llm_model": args.llm_model, + + # embedding model configuration binding/host address (if applicable)/model (if applicable) + "embedding_binding": args.embedding_binding, + "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, + "max_tokens": args.max_tokens, - "lollms_host": args.lollms_host, }, } diff --git a/lightrag/api/ollama_lightrag_server.py b/lightrag/api/ollama_lightrag_server.py deleted file mode 100644 index b3140aba..00000000 --- a/lightrag/api/ollama_lightrag_server.py +++ /dev/null @@ -1,491 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form -from pydantic import BaseModel -import logging -import argparse -from lightrag import LightRAG, QueryParam -from lightrag.llm import ollama_model_complete, ollama_embed -from lightrag.utils import EmbeddingFunc -from typing import Optional, List -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import os - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with separate working and input directories" - ) - - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # Model configuration - parser.add_argument( - "--model", - default="mistral-nemo:latest", - help="LLM model name (default: mistral-nemo:latest)", - ) - parser.add_argument( - "--embedding-model", - default="bge-m3:latest", - help="Embedding model name (default: bge-m3:latest)", - ) - parser.add_argument( - "--ollama-host", - default="http://localhost:11434", - help="Ollama host URL (default: http://localhost:11434)", - ) - - # RAG configuration - parser.add_argument( - "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" - ) - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--embedding-dim", - type=int, - default=1024, - help="Embedding dimensions (default: 1024)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - - # Logging configuration - 
parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - - -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - stream: bool = False - only_need_context: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -def create_app(args): - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.0", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # 
Initialize RAG - rag = LightRAG( - working_dir=args.working_dir, - llm_model_func=ollama_model_complete, - llm_model_name=args.model, - llm_model_max_async=args.max_async, - llm_model_max_token_size=args.max_tokens, - llm_model_kwargs={ - "host": args.ollama_host, - "options": {"num_ctx": args.max_tokens}, - }, - embedding_func=EmbeddingFunc( - embedding_dim=args.embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: ollama_embed( - texts, embed_model=args.embedding_model, host=args.ollama_host - ), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. 
Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=request.stream, - only_need_context=request.only_need_context, - ), - ) - - if request.stream: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) - else: - return QueryResponse(response=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = rag.query( - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - - async def stream_generator(): - async for chunk in response: - yield chunk - - return stream_generator() - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - await rag.ainsert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=len(rag), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - await rag.ainsert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". 
Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - "model": args.model, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - "ollama_host": args.ollama_host, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py deleted file mode 100644 index 349c09da..00000000 --- a/lightrag/api/openai_lightrag_server.py +++ /dev/null @@ -1,506 +0,0 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form -from pydantic import BaseModel -import asyncio -import logging -import argparse -from lightrag import LightRAG, QueryParam -from lightrag.llm import openai_complete_if_cache, openai_embedding -from lightrag.utils import EmbeddingFunc -from typing import Optional, List -from enum import Enum -from pathlib import Path -import shutil -import aiofiles -from ascii_colors import trace_exception -import nest_asyncio - -import os - -from fastapi import Depends, Security -from fastapi.security import APIKeyHeader -from fastapi.middleware.cors import CORSMiddleware - -from starlette.status import HTTP_403_FORBIDDEN - -# Apply nest_asyncio to solve event loop issues -nest_asyncio.apply() - - -def parse_args(): - parser = argparse.ArgumentParser( - description="LightRAG FastAPI Server with OpenAI integration" - ) - - # Server configuration - parser.add_argument( - "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=9621, help="Server port (default: 9621)" - ) - - # Directory configuration - parser.add_argument( - "--working-dir", - default="./rag_storage", - help="Working directory for RAG storage (default: ./rag_storage)", - ) - parser.add_argument( - "--input-dir", - default="./inputs", - help="Directory containing input documents (default: ./inputs)", - ) - - # Model configuration - parser.add_argument( - "--model", default="gpt-4", help="OpenAI model name (default: gpt-4)" - ) - parser.add_argument( - "--embedding-model", - default="text-embedding-3-large", - help="OpenAI embedding model (default: text-embedding-3-large)", - ) - - # RAG configuration - parser.add_argument( - "--max-tokens", - type=int, - default=32768, - help="Maximum token size (default: 32768)", - ) - parser.add_argument( - "--max-embed-tokens", - type=int, - default=8192, - help="Maximum embedding token size (default: 8192)", - ) - - # Logging 
configuration - parser.add_argument( - "--log-level", - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level (default: INFO)", - ) - - parser.add_argument( - "--key", - type=str, - help="API key for authentication. This protects lightrag server against unauthorized access", - default=None, - ) - - return parser.parse_args() - - -class DocumentManager: - """Handles document operations and tracking""" - - def __init__(self, input_dir: str, supported_extensions: tuple = (".txt", ".md")): - self.input_dir = Path(input_dir) - self.supported_extensions = supported_extensions - self.indexed_files = set() - - # Create input directory if it doesn't exist - self.input_dir.mkdir(parents=True, exist_ok=True) - - def scan_directory(self) -> List[Path]: - """Scan input directory for new files""" - new_files = [] - for ext in self.supported_extensions: - for file_path in self.input_dir.rglob(f"*{ext}"): - if file_path not in self.indexed_files: - new_files.append(file_path) - return new_files - - def mark_as_indexed(self, file_path: Path): - """Mark a file as indexed""" - self.indexed_files.add(file_path) - - def is_supported_file(self, filename: str) -> bool: - """Check if file type is supported""" - return any(filename.lower().endswith(ext) for ext in self.supported_extensions) - - -# Pydantic models -class SearchMode(str, Enum): - naive = "naive" - local = "local" - global_ = "global" - hybrid = "hybrid" - - -class QueryRequest(BaseModel): - query: str - mode: SearchMode = SearchMode.hybrid - stream: bool = False - only_need_context: bool = False - - -class QueryResponse(BaseModel): - response: str - - -class InsertTextRequest(BaseModel): - text: str - description: Optional[str] = None - - -class InsertResponse(BaseModel): - status: str - message: str - document_count: int - - -def get_api_key_dependency(api_key: Optional[str]): - if not api_key: - # If no API key is configured, return a dummy dependency that always succeeds - async def no_auth(): - return None - - return no_auth - - # If API key is configured, use proper authentication - api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - - async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): - if not api_key_header_value: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="API Key required" - ) - if api_key_header_value != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" - ) - return api_key_header_value - - return api_key_auth - - -async def get_embedding_dim(embedding_model: str) -> int: - """Get embedding dimensions for the specified model""" - test_text = ["This is a test sentence."] - embedding = await openai_embedding(test_text, model=embedding_model) - return embedding.shape[1] - - -def create_app(args): - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - - # Check if API key is provided either through env var or args - api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - - # Initialize FastAPI - app = FastAPI( - title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories" - + "(With authentication)" - if api_key - else "", - version="1.0.0", - openapi_tags=[{"name": "api"}], - ) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create the optional 
API key dependency - optional_api_key = get_api_key_dependency(api_key) - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) - - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) - - # Get embedding dimensions - embedding_dim = asyncio.run(get_embedding_dim(args.embedding_model)) - - async def async_openai_complete( - prompt, system_prompt=None, history_messages=[], **kwargs - ): - """Async wrapper for OpenAI completion""" - return await openai_complete_if_cache( - args.model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) - - # Initialize RAG with OpenAI configuration - rag = LightRAG( - working_dir=args.working_dir, - llm_model_func=async_openai_complete, - llm_model_name=args.model, - llm_model_max_token_size=args.max_tokens, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dim, - max_token_size=args.max_embed_tokens, - func=lambda texts: openai_embedding(texts, model=args.embedding_model), - ), - ) - - @app.on_event("startup") - async def startup_event(): - """Index all files in input directory during startup""" - try: - new_files = doc_manager.scan_directory() - for file_path in new_files: - try: - # Use async file reading - async with aiofiles.open(file_path, "r", encoding="utf-8") as f: - content = await f.read() - # Use the async version of insert directly - await rag.ainsert(content) - doc_manager.mark_as_indexed(file_path) - logging.info(f"Indexed file: {file_path}") - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") - - logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") - - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - - @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """Manually trigger scanning for new documents""" - try: - new_files = doc_manager.scan_directory() - indexed_count = 0 - - for file_path in new_files: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - rag.insert(content) - doc_manager.mark_as_indexed(file_path) - indexed_count += 1 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") - - return { - "status": "success", - "indexed_count": indexed_count, - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) - async def upload_to_input_dir(file: UploadFile = File(...)): - """Upload a file to the input directory""" - try: - if not doc_manager.is_supported_file(file.filename): - raise HTTPException( - status_code=400, - detail=f"Unsupported file type. 
Supported types: {doc_manager.supported_extensions}", - ) - - file_path = doc_manager.input_dir / file.filename - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - # Immediately index the uploaded file - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - rag.insert(content) - doc_manager.mark_as_indexed(file_path) - - return { - "status": "success", - "message": f"File uploaded and indexed: {file.filename}", - "total_documents": len(doc_manager.indexed_files), - } - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] - ) - async def query_text(request: QueryRequest): - try: - response = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - stream=request.stream, - only_need_context=request.only_need_context, - ), - ) - - if request.stream: - result = "" - async for chunk in response: - result += chunk - return QueryResponse(response=result) - else: - return QueryResponse(response=response) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) - async def query_text_stream(request: QueryRequest): - try: - response = rag.query( - request.query, - param=QueryParam( - mode=request.mode, - stream=True, - only_need_context=request.only_need_context, - ), - ) - - async def stream_generator(): - async for chunk in response: - yield chunk - - return stream_generator() - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/text", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_text(request: InsertTextRequest): - try: - rag.insert(request.text) - return InsertResponse( - status="success", - message="Text successfully inserted", - document_count=len(rag), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/file", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_file(file: UploadFile = File(...), description: str = Form(None)): - try: - content = await file.read() - - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - else: - raise HTTPException( - status_code=400, - detail="Unsupported file type. Only .txt and .md files are supported", - ) - - return InsertResponse( - status="success", - message=f"File '{file.filename}' successfully inserted", - document_count=1, - ) - except UnicodeDecodeError: - raise HTTPException(status_code=400, detail="File encoding not supported") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.post( - "/documents/batch", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def insert_batch(files: List[UploadFile] = File(...)): - try: - inserted_count = 0 - failed_files = [] - - for file in files: - try: - content = await file.read() - if file.filename.endswith((".txt", ".md")): - text = content.decode("utf-8") - rag.insert(text) - inserted_count += 1 - else: - failed_files.append(f"{file.filename} (unsupported type)") - except Exception as e: - failed_files.append(f"{file.filename} ({str(e)})") - - status_message = f"Successfully inserted {inserted_count} documents" - if failed_files: - status_message += f". 
Failed files: {', '.join(failed_files)}" - - return InsertResponse( - status="success" if inserted_count > 0 else "partial_success", - message=status_message, - document_count=len(files), - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.delete( - "/documents", - response_model=InsertResponse, - dependencies=[Depends(optional_api_key)], - ) - async def clear_documents(): - try: - rag.text_chunks = [] - rag.entities_vdb = None - rag.relationships_vdb = None - return InsertResponse( - status="success", - message="All documents cleared successfully", - document_count=0, - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @app.get("/health", dependencies=[Depends(optional_api_key)]) - async def get_status(): - """Get current system status""" - return { - "status": "healthy", - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "indexed_files": len(doc_manager.indexed_files), - "configuration": { - "model": args.model, - "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, - "embedding_dim": embedding_dim, - }, - } - - return app - - -def main(): - args = parse_args() - import uvicorn - - app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main() diff --git a/setup.py b/setup.py index 368610f6..38eff646 100644 --- a/setup.py +++ b/setup.py @@ -100,10 +100,7 @@ setuptools.setup( }, entry_points={ "console_scripts": [ - "lollms-lightrag-server=lightrag.api.lollms_lightrag_server:main [api]", - "ollama-lightrag-server=lightrag.api.ollama_lightrag_server:main [api]", - "openai-lightrag-server=lightrag.api.openai_lightrag_server:main [api]", - "azure-openai-lightrag-server=lightrag.api.azure_openai_lightrag_server:main [api]", + "lightrag-server=lightrag.api.lightrag_server:main [api]", ], }, ) From adb288c5bb7ac6753daf898cb99c1458f9663773 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 21:39:25 +0100 Subject: [PATCH 2/7] added timeout --- lightrag/api/lightrag_server.py | 23 +++++++++++++++++++++++ lightrag/llm.py | 4 +++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 4f8e38cd..1175afab 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -101,6 +101,12 @@ def parse_args(): help="Embedding model name (default: bge-m3:latest)", ) + parser.add_argument( + "--timeout", + default=300, + help="Timeout is seconds (useful when using slow AI)", + ) + # RAG configuration parser.add_argument( "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" @@ -139,6 +145,22 @@ def parse_args(): default=None, ) + # Optional https parameters + parser.add_argument( + "--ssl", + action="store_true", + help="Enable HTTPS (default: False)" + ) + parser.add_argument( + "--ssl-certfile", + default=None, + help="Path to SSL certificate file (required if --ssl is enabled)" + ) + parser.add_argument( + "--ssl-keyfile", + default=None, + help="Path to SSL private key file (required if --ssl is enabled)" + ) return parser.parse_args() @@ -284,6 +306,7 @@ def create_app(args): llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, + "timeout":args.timeout "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( diff --git a/lightrag/llm.py b/lightrag/llm.py index 0c17019a..4e01dd51 100644 --- a/lightrag/llm.py +++ 
b/lightrag/llm.py @@ -336,6 +336,7 @@ async def hf_model_if_cache( (RateLimitError, APIConnectionError, APITimeoutError) ), ) + async def ollama_model_if_cache( model, prompt, @@ -406,8 +407,9 @@ async def lollms_model_if_cache( full_prompt += prompt request_data["prompt"] = full_prompt + timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", 300)) # 300 seconds = 5 minutes - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=timeout) as session: if stream: async def inner(): From ab3cc3f0f47790ea3a713b2a3831aac1efb3c854 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 21:39:41 +0100 Subject: [PATCH 3/7] fixed missing coma --- lightrag/api/lightrag_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1175afab..d4cddd6c 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -306,7 +306,7 @@ def create_app(args): llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, - "timeout":args.timeout + "timeout":args.timeout, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( From a619b010640a356d95e241ff07e17717ee4c2fe1 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 22:17:13 +0100 Subject: [PATCH 4/7] Next test of timeout --- lightrag/api/lightrag_server.py | 11 ++++++++--- lightrag/llm.py | 3 +-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d4cddd6c..40b63463 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -101,12 +101,17 @@ def parse_args(): help="Embedding model name (default: bge-m3:latest)", ) + def timeout_type(value): + if value is None or value == "None": + return None + return int(value) + parser.add_argument( "--timeout", - default=300, - help="Timeout is seconds (useful when using slow AI)", + default=None, + type=timeout_type, + help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", ) - # RAG configuration parser.add_argument( "--max-async", type=int, default=4, help="Maximum async operations (default: 4)" diff --git a/lightrag/llm.py b/lightrag/llm.py index 4e01dd51..7a51d025 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -407,11 +407,10 @@ async def lollms_model_if_cache( full_prompt += prompt request_data["prompt"] = full_prompt - timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", 300)) # 300 seconds = 5 minutes + timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None)) async with aiohttp.ClientSession(timeout=timeout) as session: if stream: - async def inner(): async with session.post( f"{base_url}/lollms_generate", json=request_data From e21fbef60b702e3205dbdc3d92c680bc7d2b90c5 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 10 Jan 2025 22:38:57 +0100 Subject: [PATCH 5/7] updated documlentation --- README.md | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 278f6a72..57aee435 100644 --- a/README.md +++ b/README.md @@ -959,23 +959,26 @@ Each server has its own specific configuration options: | Parameter | Default | Description | |-----------|---------|-------------| -| --host | 0.0.0.0 | RAG server host | -| --port | 9621 | RAG server port | -| --llm-binding | ollama | LLM binding to be used. 
Supported: lollms, ollama, openai (default: ollama) | -| --llm-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | llm server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | -| --model | mistral-nemo:latest | LLM model name | -| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama) | -| --embedding-binding-host | http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai | embedding server host URL (default: http://localhost:11434 if the binding is ollama, http://localhost:9600 if the binding is lollms, https://api.openai.com/v1 if the binding is openai) | +| --host | 0.0.0.0 | Server host | +| --port | 9621 | Server port | +| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai | +| --llm-binding-host | (dynamic) | LLM server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) | +| --llm-model | mistral-nemo:latest | LLM model name | +| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai | +| --embedding-binding-host | (dynamic) | Embedding server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) | | --embedding-model | bge-m3:latest | Embedding model name | -| --embedding-binding-host | http://localhost:9600 | LoLLMS backend URL | -| --working-dir | ./rag_storage | Working directory for RAG | +| --working-dir | ./rag_storage | Working directory for RAG storage | +| --input-dir | ./inputs | Directory containing input documents | | --max-async | 4 | Maximum async operations | | --max-tokens | 32768 | Maximum token size | | --embedding-dim | 1024 | Embedding dimensions | | --max-embed-tokens | 8192 | Maximum embedding token size | -| --input-file | ./book.txt | Initial input file | -| --log-level | INFO | Logging level | -| --key | none | Access Key to protect the lightrag service | +| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout | +| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) | +| --key | None | API key for authentication. 
Protects lightrag server against unauthorized access | +| --ssl | False | Enable HTTPS | +| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) | +| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | From e0e656ab014138c129aab9f48e7f3f8bcf6b57b7 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 11 Jan 2025 01:35:49 +0100 Subject: [PATCH 6/7] Added ssl support --- lightrag/api/lightrag_server.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 40b63463..1f88e776 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -262,7 +262,16 @@ def create_app(args): raise Exception("embedding binding not supported") - + # Add SSL validation + if args.ssl: + if not args.ssl_certfile or not args.ssl_keyfile: + raise Exception("SSL certificate and key files must be provided when SSL is enabled") + if not os.path.exists(args.ssl_certfile): + raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") + if not os.path.exists(args.ssl_keyfile): + raise Exception(f"SSL key file not found: {args.ssl_keyfile}") + + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -577,7 +586,17 @@ def main(): import uvicorn app = create_app(args) - uvicorn.run(app, host=args.host, port=args.port) + uvicorn_config = { + "app": app, + "host": args.host, + "port": args.port, + } + if args.ssl: + uvicorn_config.update({ + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + }) + uvicorn.run(**uvicorn_config) if __name__ == "__main__": From 224fce9b1b1a887a998d2ee818f0855c950422de Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 11 Jan 2025 01:37:07 +0100 Subject: [PATCH 7/7] run precommit to fix linting issues --- lightrag/api/lightrag_server.py | 83 ++++++++++++++++++++------------- lightrag/llm.py | 2 +- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1f88e776..644e622d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -23,21 +23,25 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN + def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": "http://localhost:11434", "lollms": "http://localhost:9600", "azure_openai": "https://api.openai.com/v1", - "openai": "https://api.openai.com/v1" + "openai": "https://api.openai.com/v1", } - return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown + return default_hosts.get( + binding_type, "http://localhost:11434" + ) # fallback to ollama if unknown + def parse_args(): parser = argparse.ArgumentParser( description="LightRAG FastAPI Server with separate working and input directories" ) - #Start by the bindings + # Start by the bindings parser.add_argument( "--llm-binding", default="ollama", @@ -48,7 +52,7 @@ def parse_args(): default="ollama", help="Embedding binding to be used. 
Supported: lollms, ollama, openai (default: ollama)", ) - + # Parse just these arguments first temp_args, _ = parser.parse_known_args() @@ -152,19 +156,17 @@ def parse_args(): # Optional https parameters parser.add_argument( - "--ssl", - action="store_true", - help="Enable HTTPS (default: False)" + "--ssl", action="store_true", help="Enable HTTPS (default: False)" ) parser.add_argument( "--ssl-certfile", default=None, - help="Path to SSL certificate file (required if --ssl is enabled)" + help="Path to SSL certificate file (required if --ssl is enabled)", ) parser.add_argument( "--ssl-keyfile", - default=None, - help="Path to SSL private key file (required if --ssl is enabled)" + default=None, + help="Path to SSL private key file (required if --ssl is enabled)", ) return parser.parse_args() @@ -261,17 +263,17 @@ def create_app(args): if args.embedding_binding not in ["lollms", "ollama", "openai"]: raise Exception("embedding binding not supported") - # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception("SSL certificate and key files must be provided when SSL is enabled") + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): raise Exception(f"SSL key file not found: {args.ssl_keyfile}") - - + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -309,33 +311,48 @@ def create_app(args): # Initialize document manager doc_manager = DocumentManager(args.input_dir) - - # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache, + llm_model_func=lollms_model_complete + if args.llm_binding == "lollms" + else ollama_model_complete + if args.llm_binding == "ollama" + else azure_openai_complete_if_cache + if args.llm_binding == "azure_openai" + else openai_complete_if_cache, llm_model_name=args.llm_model, llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, - "timeout":args.timeout, + "timeout": args.timeout, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( embedding_dim=args.embedding_dim, max_token_size=args.max_embed_tokens, func=lambda texts: lollms_embed( - texts, embed_model=args.embedding_model, host=args.embedding_binding_host - ) if args.llm_binding=="lollms" else ollama_embed( - texts, embed_model=args.embedding_model, host=args.embedding_binding_host - ) if args.llm_binding=="ollama" else azure_openai_embedding( - texts, model=args.embedding_model # no host is used for openai - ) if args.llm_binding=="azure_openai" else openai_embedding( - texts, model=args.embedding_model # no host is used for openai + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, ) - + if args.llm_binding == "lollms" + else ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.llm_binding == "ollama" + else azure_openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ) + if args.llm_binding == "azure_openai" + else openai_embedding( + texts, + model=args.embedding_model, # no 
host is used for openai + ), ), ) @@ -568,12 +585,10 @@ def create_app(args): "llm_binding": args.llm_binding, "llm_binding_host": args.llm_binding_host, "llm_model": args.llm_model, - # embedding model configuration binding/host address (if applicable)/model (if applicable) "embedding_binding": args.embedding_binding, "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, }, } @@ -590,12 +605,14 @@ def main(): "app": app, "host": args.host, "port": args.port, - } + } if args.ssl: - uvicorn_config.update({ - "ssl_certfile": args.ssl_certfile, - "ssl_keyfile": args.ssl_keyfile, - }) + uvicorn_config.update( + { + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + } + ) uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm.py b/lightrag/llm.py index 7a51d025..c49ed138 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -336,7 +336,6 @@ async def hf_model_if_cache( (RateLimitError, APIConnectionError, APITimeoutError) ), ) - async def ollama_model_if_cache( model, prompt, @@ -411,6 +410,7 @@ async def lollms_model_if_cache( async with aiohttp.ClientSession(timeout=timeout) as session: if stream: + async def inner(): async with session.post( f"{base_url}/lollms_generate", json=request_data
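
As a quick sanity check of the unified `lightrag-server` these patches produce, here is a minimal client sketch, assuming the default `localhost:9621` bind from the options table, the `/health` and `/query` routes and `QueryResponse` schema shown above, and an `X-API-Key` header for the optional `--key` protection (the header name is an assumption; these patches do not spell it out).

```python
# Hypothetical client sketch; adjust the host, port, key header, and query mode to your setup.
import requests

BASE_URL = "http://localhost:9621"       # default --host/--port from the options table
HEADERS = {"X-API-Key": "<your-key>"}    # only needed if the server was started with --key

# Confirm the server is up and inspect the llm/embedding bindings it was started with
health = requests.get(f"{BASE_URL}/health", headers=HEADERS)
health.raise_for_status()
print(health.json())

# Run a query against the indexed documents
payload = {
    "query": "What does this corpus say about timeouts?",
    "mode": "hybrid",            # one of LightRAG's query modes
    "stream": False,
    "only_need_context": False,
}
r = requests.post(f"{BASE_URL}/query", json=payload, headers=HEADERS)
r.raise_for_status()
print(r.json()["response"])
```

For incremental output, the same payload can be sent to `/query/stream`, mirroring the streaming handler shown earlier.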