Added environment-variable control of all LightRAG server parameters, in preparation for use in Docker

This commit is contained in:
Saifeddine ALOUI
2025-01-16 23:21:50 +01:00
parent b2e7c75f5a
commit ea566d815d
2 changed files with 91 additions and 41 deletions

View File

@@ -20,6 +20,7 @@ import os
from fastapi import Depends, Security from fastapi import Depends, Security
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm import pipmaster as pm
@@ -36,74 +37,112 @@ def get_default_host(binding_type: str) -> str:
binding_type, "http://localhost:11434" binding_type, "http://localhost:11434"
) # fallback to ollama if unknown ) # fallback to ollama if unknown
from dotenv import load_dotenv
import os
def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
    """
    Get value from environment variable with type conversion

    Args:
        env_key (str): Environment variable key
        default (Any): Default value if env variable is not set, or if the
            raw value cannot be converted by ``value_type``
        value_type (type): Type (or any one-argument callable) used to
            convert the raw string value

    Returns:
        Any: Converted value from environment or default
    """
    value = os.getenv(env_key)
    if value is None:
        return default

    if value_type is bool:
        # bool("false") is True, so booleans are parsed by keyword instead.
        # Surrounding whitespace is ignored so values such as "true " coming
        # from env files are still recognised.
        return value.strip().lower() in ("true", "1", "yes")

    try:
        return value_type(value)
    except (ValueError, TypeError):
        # A malformed value must not abort startup: fall back to the default.
        return default
def parse_args() -> argparse.Namespace:
"""
Parse command line arguments with environment variable fallback
Returns:
argparse.Namespace: Parsed arguments
"""
# Load environment variables from .env file
load_dotenv()
def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories" description="LightRAG FastAPI Server with separate working and input directories"
) )
# Start by the bindings # Bindings (with env var support)
parser.add_argument( parser.add_argument(
"--llm-binding", "--llm-binding",
default="ollama", default=get_env_value("LLM_BINDING", "ollama"),
help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)", help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
) )
parser.add_argument( parser.add_argument(
"--embedding-binding", "--embedding-binding",
default="ollama", default=get_env_value("EMBEDDING_BINDING", "ollama"),
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
) )
# Parse just these arguments first # Parse temporary args for host defaults
temp_args, _ = parser.parse_known_args() temp_args, _ = parser.parse_known_args()
# Add remaining arguments with dynamic defaults for hosts
# Server configuration # Server configuration
parser.add_argument( parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)" "--host",
default=get_env_value("HOST", "0.0.0.0"),
help="Server host (default: from env or 0.0.0.0)"
) )
parser.add_argument( parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)" "--port",
type=int,
default=get_env_value("PORT", 9621, int),
help="Server port (default: from env or 9621)"
) )
# Directory configuration # Directory configuration
parser.add_argument( parser.add_argument(
"--working-dir", "--working-dir",
default="./rag_storage", default=get_env_value("WORKING_DIR", "./rag_storage"),
help="Working directory for RAG storage (default: ./rag_storage)", help="Working directory for RAG storage (default: from env or ./rag_storage)",
) )
parser.add_argument( parser.add_argument(
"--input-dir", "--input-dir",
default="./inputs", default=get_env_value("INPUT_DIR", "./inputs"),
help="Directory containing input documents (default: ./inputs)", help="Directory containing input documents (default: from env or ./inputs)",
) )
# LLM Model configuration # LLM Model configuration
default_llm_host = get_default_host(temp_args.llm_binding) default_llm_host = get_env_value("LLM_BINDING_HOST", get_default_host(temp_args.llm_binding))
parser.add_argument( parser.add_argument(
"--llm-binding-host", "--llm-binding-host",
default=default_llm_host, default=default_llm_host,
help=f"llm server host URL (default: {default_llm_host})", help=f"llm server host URL (default: from env or {default_llm_host})",
) )
parser.add_argument( parser.add_argument(
"--llm-model", "--llm-model",
default="mistral-nemo:latest", default=get_env_value("LLM_MODEL", "mistral-nemo:latest"),
help="LLM model name (default: mistral-nemo:latest)", help="LLM model name (default: from env or mistral-nemo:latest)",
) )
# Embedding model configuration # Embedding model configuration
default_embedding_host = get_default_host(temp_args.embedding_binding) default_embedding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(temp_args.embedding_binding))
parser.add_argument( parser.add_argument(
"--embedding-binding-host", "--embedding-binding-host",
default=default_embedding_host, default=default_embedding_host,
help=f"embedding server host URL (default: {default_embedding_host})", help=f"embedding server host URL (default: from env or {default_embedding_host})",
) )
parser.add_argument( parser.add_argument(
"--embedding-model", "--embedding-model",
default="bge-m3:latest", default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"),
help="Embedding model name (default: bge-m3:latest)", help="Embedding model name (default: from env or bge-m3:latest)",
) )
def timeout_type(value): def timeout_type(value):
@@ -113,62 +152,70 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--timeout", "--timeout",
default=None, default=get_env_value("TIMEOUT", None, timeout_type),
type=timeout_type, type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout", help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
) )
# RAG configuration # RAG configuration
parser.add_argument( parser.add_argument(
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)" "--max-async",
type=int,
default=get_env_value("MAX_ASYNC", 4, int),
help="Maximum async operations (default: from env or 4)"
) )
parser.add_argument( parser.add_argument(
"--max-tokens", "--max-tokens",
type=int, type=int,
default=32768, default=get_env_value("MAX_TOKENS", 32768, int),
help="Maximum token size (default: 32768)", help="Maximum token size (default: from env or 32768)",
) )
parser.add_argument( parser.add_argument(
"--embedding-dim", "--embedding-dim",
type=int, type=int,
default=1024, default=get_env_value("EMBEDDING_DIM", 1024, int),
help="Embedding dimensions (default: 1024)", help="Embedding dimensions (default: from env or 1024)",
) )
parser.add_argument( parser.add_argument(
"--max-embed-tokens", "--max-embed-tokens",
type=int, type=int,
default=8192, default=get_env_value("MAX_EMBED_TOKENS", 8192, int),
help="Maximum embedding token size (default: 8192)", help="Maximum embedding token size (default: from env or 8192)",
) )
# Logging configuration # Logging configuration
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
default="INFO", default=get_env_value("LOG_LEVEL", "INFO"),
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)", help="Logging level (default: from env or INFO)",
) )
parser.add_argument( parser.add_argument(
"--key", "--key",
type=str, type=str,
default=get_env_value("LIGHTRAG_API_KEY", None),
help="API key for authentication. This protects lightrag server against unauthorized access", help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
) )
# Optional https parameters # Optional https parameters
parser.add_argument( parser.add_argument(
"--ssl", action="store_true", help="Enable HTTPS (default: False)" "--ssl",
action="store_true",
default=get_env_value("SSL", False, bool),
help="Enable HTTPS (default: from env or False)"
) )
parser.add_argument( parser.add_argument(
"--ssl-certfile", "--ssl-certfile",
default=None, default=get_env_value("SSL_CERTFILE", None),
help="Path to SSL certificate file (required if --ssl is enabled)", help="Path to SSL certificate file (required if --ssl is enabled)",
) )
parser.add_argument( parser.add_argument(
"--ssl-keyfile", "--ssl-keyfile",
default=None, default=get_env_value("SSL_KEYFILE", None),
help="Path to SSL private key file (required if --ssl is enabled)", help="Path to SSL private key file (required if --ssl is enabled)",
) )
return parser.parse_args() return parser.parse_args()
@@ -435,9 +482,11 @@ def create_app(args):
else: else:
logging.warning(f"No content extracted from file: {file_path}") logging.warning(f"No content extracted from file: {file_path}")
@app.on_event("startup")
async def startup_event(): @asynccontextmanager
"""Index all files in input directory during startup""" async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Startup logic
try: try:
new_files = doc_manager.scan_directory() new_files = doc_manager.scan_directory()
for file_path in new_files: for file_path in new_files:
@@ -448,7 +497,6 @@ def create_app(args):
logging.error(f"Error indexing file {file_path}: {str(e)}") logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}") logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e: except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}") logging.error(f"Error during startup indexing: {str(e)}")
@@ -521,6 +569,7 @@ def create_app(args):
else: else:
return QueryResponse(response=response) return QueryResponse(response=response)
except Exception as e: except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)]) @app.post("/query/stream", dependencies=[Depends(optional_api_key)])

View File

@@ -16,3 +16,4 @@ torch
tqdm tqdm
transformers transformers
uvicorn uvicorn
python-dotenv