Added environment variables control of all lightrag server parameters preparing for the usage in docker

This commit is contained in:
Saifeddine ALOUI
2025-01-16 23:21:50 +01:00
parent b2e7c75f5a
commit ea566d815d
2 changed files with 91 additions and 41 deletions

View File

@@ -20,6 +20,7 @@ import os
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.status import HTTP_403_FORBIDDEN
import pipmaster as pm
@@ -36,74 +37,112 @@ def get_default_host(binding_type: str) -> str:
binding_type, "http://localhost:11434"
) # fallback to ollama if unknown
from dotenv import load_dotenv
import os
def parse_args():
def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
"""
Get value from environment variable with type conversion
Args:
env_key (str): Environment variable key
default (Any): Default value if env variable is not set
value_type (type): Type to convert the value to
Returns:
Any: Converted value from environment or default
"""
value = os.getenv(env_key)
if value is None:
return default
if value_type == bool:
return value.lower() in ('true', '1', 'yes')
try:
return value_type(value)
except ValueError:
return default
def parse_args() -> argparse.Namespace:
"""
Parse command line arguments with environment variable fallback
Returns:
argparse.Namespace: Parsed arguments
"""
# Load environment variables from .env file
load_dotenv()
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
# Start by the bindings
# Bindings (with env var support)
parser.add_argument(
"--llm-binding",
default="ollama",
help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
default=get_env_value("LLM_BINDING", "ollama"),
help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
)
parser.add_argument(
"--embedding-binding",
default="ollama",
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
default=get_env_value("EMBEDDING_BINDING", "ollama"),
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
)
# Parse just these arguments first
# Parse temporary args for host defaults
temp_args, _ = parser.parse_known_args()
# Add remaining arguments with dynamic defaults for hosts
# Server configuration
parser.add_argument(
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
"--host",
default=get_env_value("HOST", "0.0.0.0"),
help="Server host (default: from env or 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=9621, help="Server port (default: 9621)"
"--port",
type=int,
default=get_env_value("PORT", 9621, int),
help="Server port (default: from env or 9621)"
)
# Directory configuration
parser.add_argument(
"--working-dir",
default="./rag_storage",
help="Working directory for RAG storage (default: ./rag_storage)",
default=get_env_value("WORKING_DIR", "./rag_storage"),
help="Working directory for RAG storage (default: from env or ./rag_storage)",
)
parser.add_argument(
"--input-dir",
default="./inputs",
help="Directory containing input documents (default: ./inputs)",
default=get_env_value("INPUT_DIR", "./inputs"),
help="Directory containing input documents (default: from env or ./inputs)",
)
# LLM Model configuration
default_llm_host = get_default_host(temp_args.llm_binding)
default_llm_host = get_env_value("LLM_BINDING_HOST", get_default_host(temp_args.llm_binding))
parser.add_argument(
"--llm-binding-host",
default=default_llm_host,
help=f"llm server host URL (default: {default_llm_host})",
help=f"llm server host URL (default: from env or {default_llm_host})",
)
parser.add_argument(
"--llm-model",
default="mistral-nemo:latest",
help="LLM model name (default: mistral-nemo:latest)",
default=get_env_value("LLM_MODEL", "mistral-nemo:latest"),
help="LLM model name (default: from env or mistral-nemo:latest)",
)
# Embedding model configuration
default_embedding_host = get_default_host(temp_args.embedding_binding)
default_embedding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(temp_args.embedding_binding))
parser.add_argument(
"--embedding-binding-host",
default=default_embedding_host,
help=f"embedding server host URL (default: {default_embedding_host})",
help=f"embedding server host URL (default: from env or {default_embedding_host})",
)
parser.add_argument(
"--embedding-model",
default="bge-m3:latest",
help="Embedding model name (default: bge-m3:latest)",
default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"),
help="Embedding model name (default: from env or bge-m3:latest)",
)
def timeout_type(value):
@@ -113,62 +152,70 @@ def parse_args():
parser.add_argument(
"--timeout",
default=None,
default=get_env_value("TIMEOUT", None, timeout_type),
type=timeout_type,
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
)
# RAG configuration
parser.add_argument(
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
"--max-async",
type=int,
default=get_env_value("MAX_ASYNC", 4, int),
help="Maximum async operations (default: from env or 4)"
)
parser.add_argument(
"--max-tokens",
type=int,
default=32768,
help="Maximum token size (default: 32768)",
default=get_env_value("MAX_TOKENS", 32768, int),
help="Maximum token size (default: from env or 32768)",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=1024,
help="Embedding dimensions (default: 1024)",
default=get_env_value("EMBEDDING_DIM", 1024, int),
help="Embedding dimensions (default: from env or 1024)",
)
parser.add_argument(
"--max-embed-tokens",
type=int,
default=8192,
help="Maximum embedding token size (default: 8192)",
default=get_env_value("MAX_EMBED_TOKENS", 8192, int),
help="Maximum embedding token size (default: from env or 8192)",
)
# Logging configuration
parser.add_argument(
"--log-level",
default="INFO",
default=get_env_value("LOG_LEVEL", "INFO"),
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level (default: INFO)",
help="Logging level (default: from env or INFO)",
)
parser.add_argument(
"--key",
type=str,
default=get_env_value("LIGHTRAG_API_KEY", None),
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)
# Optional https parameters
parser.add_argument(
"--ssl", action="store_true", help="Enable HTTPS (default: False)"
"--ssl",
action="store_true",
default=get_env_value("SSL", False, bool),
help="Enable HTTPS (default: from env or False)"
)
parser.add_argument(
"--ssl-certfile",
default=None,
default=get_env_value("SSL_CERTFILE", None),
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=None,
default=get_env_value("SSL_KEYFILE", None),
help="Path to SSL private key file (required if --ssl is enabled)",
)
return parser.parse_args()
@@ -434,10 +481,12 @@ def create_app(args):
logging.info(f"Successfully indexed file: {file_path}")
else:
logging.warning(f"No content extracted from file: {file_path}")
@app.on_event("startup")
async def startup_event():
"""Index all files in input directory during startup"""
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Startup logic
try:
new_files = doc_manager.scan_directory()
for file_path in new_files:
@@ -448,7 +497,6 @@ def create_app(args):
logging.error(f"Error indexing file {file_path}: {str(e)}")
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")
@@ -521,6 +569,7 @@ def create_app(args):
else:
return QueryResponse(response=response)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])