From ea566d815dfed438cbb1f2bd16745278ecae9070 Mon Sep 17 00:00:00 2001
From: Saifeddine ALOUI
Date: Thu, 16 Jan 2025 23:21:50 +0100
Subject: [PATCH] Added environment variable control of all LightRAG server
 parameters, in preparation for Docker usage

---
 lightrag/api/lightrag_server.py | 131 ++++++++++++++++++++++----------
 lightrag/api/requirements.txt   |   1 +
 2 files changed, 91 insertions(+), 41 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 0d154b38..cec3c089 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -20,6 +20,7 @@ import os
 from fastapi import Depends, Security
 from fastapi.security import APIKeyHeader
 from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
 from starlette.status import HTTP_403_FORBIDDEN
 
 import pipmaster as pm
@@ -36,74 +37,112 @@ def get_default_host(binding_type: str) -> str:
         binding_type, "http://localhost:11434"
     )  # fallback to ollama if unknown
 
+from dotenv import load_dotenv
+import os
 
 
-def parse_args():
+def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
+    """
+    Get value from environment variable with type conversion
+
+    Args:
+        env_key (str): Environment variable key
+        default (Any): Default value if env variable is not set
+        value_type (type): Type to convert the value to
+
+    Returns:
+        Any: Converted value from environment or default
+    """
+    value = os.getenv(env_key)
+    if value is None:
+        return default
+
+    if value_type == bool:
+        return value.lower() in ('true', '1', 'yes')
+    try:
+        return value_type(value)
+    except ValueError:
+        return default
+
+def parse_args() -> argparse.Namespace:
+    """
+    Parse command line arguments with environment variable fallback
+
+    Returns:
+        argparse.Namespace: Parsed arguments
+    """
+    # Load environment variables from .env file
+    load_dotenv()
+
     parser = argparse.ArgumentParser(
         description="LightRAG FastAPI Server with separate working and input directories"
     )
 
-    # Start by the bindings
+    # Bindings (with env var support)
     parser.add_argument(
         "--llm-binding",
-        default="ollama",
-        help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
+        default=get_env_value("LLM_BINDING", "ollama"),
+        help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
     )
     parser.add_argument(
         "--embedding-binding",
-        default="ollama",
-        help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
+        default=get_env_value("EMBEDDING_BINDING", "ollama"),
+        help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
     )
 
-    # Parse just these arguments first
+    # Parse temporary args for host defaults
     temp_args, _ = parser.parse_known_args()
 
-    # Add remaining arguments with dynamic defaults for hosts
     # Server configuration
     parser.add_argument(
-        "--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
+        "--host",
+        default=get_env_value("HOST", "0.0.0.0"),
+        help="Server host (default: from env or 0.0.0.0)"
    )
     parser.add_argument(
-        "--port", type=int, default=9621, help="Server port (default: 9621)"
+        "--port",
+        type=int,
+        default=get_env_value("PORT", 9621, int),
+        help="Server port (default: from env or 9621)"
    )
 
     # Directory configuration
     parser.add_argument(
         "--working-dir",
-        default="./rag_storage",
-        help="Working directory for RAG storage (default: ./rag_storage)",
+        default=get_env_value("WORKING_DIR", "./rag_storage"),
+        help="Working directory for RAG storage (default: from env or ./rag_storage)",
     )
     parser.add_argument(
         "--input-dir",
-        default="./inputs",
-        help="Directory containing input documents (default: ./inputs)",
+        default=get_env_value("INPUT_DIR", "./inputs"),
+        help="Directory containing input documents (default: from env or ./inputs)",
     )
 
     # LLM Model configuration
-    default_llm_host = get_default_host(temp_args.llm_binding)
+    default_llm_host = get_env_value("LLM_BINDING_HOST", get_default_host(temp_args.llm_binding))
     parser.add_argument(
         "--llm-binding-host",
         default=default_llm_host,
-        help=f"llm server host URL (default: {default_llm_host})",
+        help=f"llm server host URL (default: from env or {default_llm_host})",
     )
     parser.add_argument(
         "--llm-model",
-        default="mistral-nemo:latest",
-        help="LLM model name (default: mistral-nemo:latest)",
+        default=get_env_value("LLM_MODEL", "mistral-nemo:latest"),
+        help="LLM model name (default: from env or mistral-nemo:latest)",
     )
 
     # Embedding model configuration
-    default_embedding_host = get_default_host(temp_args.embedding_binding)
+    default_embedding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(temp_args.embedding_binding))
     parser.add_argument(
         "--embedding-binding-host",
         default=default_embedding_host,
-        help=f"embedding server host URL (default: {default_embedding_host})",
+        help=f"embedding server host URL (default: from env or {default_embedding_host})",
     )
     parser.add_argument(
         "--embedding-model",
-        default="bge-m3:latest",
-        help="Embedding model name (default: bge-m3:latest)",
+        default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"),
+        help="Embedding model name (default: from env or bge-m3:latest)",
     )
 
     def timeout_type(value):
@@ -113,62 +152,70 @@ def parse_args():
 
     parser.add_argument(
         "--timeout",
-        default=None,
+        default=get_env_value("TIMEOUT", None, timeout_type),
         type=timeout_type,
         help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
     )
 
+
     # RAG configuration
     parser.add_argument(
-        "--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
+        "--max-async",
+        type=int,
+        default=get_env_value("MAX_ASYNC", 4, int),
+        help="Maximum async operations (default: from env or 4)"
    )
     parser.add_argument(
         "--max-tokens",
         type=int,
-        default=32768,
-        help="Maximum token size (default: 32768)",
+        default=get_env_value("MAX_TOKENS", 32768, int),
+        help="Maximum token size (default: from env or 32768)",
     )
     parser.add_argument(
         "--embedding-dim",
         type=int,
-        default=1024,
-        help="Embedding dimensions (default: 1024)",
+        default=get_env_value("EMBEDDING_DIM", 1024, int),
+        help="Embedding dimensions (default: from env or 1024)",
     )
     parser.add_argument(
         "--max-embed-tokens",
         type=int,
-        default=8192,
-        help="Maximum embedding token size (default: 8192)",
+        default=get_env_value("MAX_EMBED_TOKENS", 8192, int),
+        help="Maximum embedding token size (default: from env or 8192)",
     )
 
     # Logging configuration
     parser.add_argument(
         "--log-level",
-        default="INFO",
+        default=get_env_value("LOG_LEVEL", "INFO"),
         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
-        help="Logging level (default: INFO)",
+        help="Logging level (default: from env or INFO)",
     )
     parser.add_argument(
         "--key",
         type=str,
+        default=get_env_value("LIGHTRAG_API_KEY", None),
         help="API key for authentication. This protects lightrag server against unauthorized access",
-        default=None,
     )
 
     # Optional https parameters
     parser.add_argument(
-        "--ssl", action="store_true", help="Enable HTTPS (default: False)"
+        "--ssl",
+        action="store_true",
+        default=get_env_value("SSL", False, bool),
+        help="Enable HTTPS (default: from env or False)"
    )
     parser.add_argument(
         "--ssl-certfile",
-        default=None,
+        default=get_env_value("SSL_CERTFILE", None),
         help="Path to SSL certificate file (required if --ssl is enabled)",
     )
     parser.add_argument(
         "--ssl-keyfile",
-        default=None,
+        default=get_env_value("SSL_KEYFILE", None),
         help="Path to SSL private key file (required if --ssl is enabled)",
     )
+
     return parser.parse_args()
 
 
@@ -434,10 +481,12 @@ def create_app(args):
             logging.info(f"Successfully indexed file: {file_path}")
         else:
             logging.warning(f"No content extracted from file: {file_path}")
-
-    @app.on_event("startup")
-    async def startup_event():
-        """Index all files in input directory during startup"""
+
+
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        """Lifespan context manager for startup and shutdown events"""
+        # Startup logic
         try:
             new_files = doc_manager.scan_directory()
             for file_path in new_files:
@@ -448,7 +497,6 @@ def create_app(args):
                 logging.error(f"Error indexing file {file_path}: {str(e)}")
 
             logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
-
         except Exception as e:
             logging.error(f"Error during startup indexing: {str(e)}")
 
@@ -521,6 +569,7 @@ def create_app(args):
             else:
                 return QueryResponse(response=response)
         except Exception as e:
+            trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
diff --git a/lightrag/api/requirements.txt b/lightrag/api/requirements.txt
index 9154809c..0d5b82f6 100644
--- a/lightrag/api/requirements.txt
+++ b/lightrag/api/requirements.txt
@@ -16,3 +16,4 @@ torch
 tqdm
 transformers
 uvicorn
+python-dotenv
\ No newline at end of file
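
The net effect of the parser changes above is a fixed precedence for every server parameter: an explicit CLI flag wins, otherwise the matching environment variable (loaded from a .env file via python-dotenv) is used, otherwise the built-in default applies. The sketch below is illustrative only and not part of the patch: the helper is a trimmed copy of get_env_value, and the env keys PORT and MAX_ASYNC match the ones the diff introduces.

# Illustrative only: CLI flag > environment variable > built-in default.
import argparse
import os
from typing import Any


def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
    """Return env_key converted to value_type, or default if unset/invalid."""
    value = os.getenv(env_key)
    if value is None:
        return default
    if value_type == bool:
        return value.lower() in ("true", "1", "yes")
    try:
        return value_type(value)
    except ValueError:
        return default


if __name__ == "__main__":
    os.environ["PORT"] = "8080"  # simulate a Docker-style environment override

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=get_env_value("PORT", 9621, int))
    parser.add_argument("--max-async", type=int, default=get_env_value("MAX_ASYNC", 4, int))

    print(parser.parse_args([]))                  # env wins: port=8080, max_async=4
    print(parser.parse_args(["--port", "7000"]))  # explicit flag wins: port=7000

One caveat visible in the diff itself: boolean flags such as --ssl keep action="store_true", so once SSL=true is set in the environment there is no command-line switch to turn it back off.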
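
The other behavioural change in the diff is the move from the deprecated @app.on_event("startup") hook to an asynccontextmanager-based lifespan handler. The hunks above only show the handler being defined, not wired into the application, so the following is a minimal sketch of the standard FastAPI lifespan pattern under the assumption that the handler is passed via FastAPI(lifespan=...); the print calls stand in for the document-scanning and indexing logic in the patch.

# Minimal FastAPI lifespan sketch (assumed wiring; not taken verbatim from the patch).
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once before the server starts accepting requests,
    # where the patch scans the input directory and indexes new documents.
    print("startup: scan input directory and index new files")
    yield
    # Shutdown: runs after the server stops accepting requests.
    print("shutdown: nothing to clean up")


app = FastAPI(lifespan=lifespan)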