Added environment variables control of all lightrag server parameters preparing for the usage in docker
This commit is contained in:
@@ -20,6 +20,7 @@ import os
|
|||||||
from fastapi import Depends, Security
|
from fastapi import Depends, Security
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from starlette.status import HTTP_403_FORBIDDEN
|
from starlette.status import HTTP_403_FORBIDDEN
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
@@ -36,74 +37,112 @@ def get_default_host(binding_type: str) -> str:
|
|||||||
binding_type, "http://localhost:11434"
|
binding_type, "http://localhost:11434"
|
||||||
) # fallback to ollama if unknown
|
) # fallback to ollama if unknown
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
|
||||||
|
"""
|
||||||
|
Get value from environment variable with type conversion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_key (str): Environment variable key
|
||||||
|
default (Any): Default value if env variable is not set
|
||||||
|
value_type (type): Type to convert the value to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: Converted value from environment or default
|
||||||
|
"""
|
||||||
|
value = os.getenv(env_key)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
if value_type == bool:
|
||||||
|
return value.lower() in ('true', '1', 'yes')
|
||||||
|
try:
|
||||||
|
return value_type(value)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
"""
|
||||||
|
Parse command line arguments with environment variable fallback
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
argparse.Namespace: Parsed arguments
|
||||||
|
"""
|
||||||
|
# Load environment variables from .env file
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="LightRAG FastAPI Server with separate working and input directories"
|
description="LightRAG FastAPI Server with separate working and input directories"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start by the bindings
|
# Bindings (with env var support)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--llm-binding",
|
"--llm-binding",
|
||||||
default="ollama",
|
default=get_env_value("LLM_BINDING", "ollama"),
|
||||||
help="LLM binding to be used. Supported: lollms, ollama, openai (default: ollama)",
|
help="LLM binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding-binding",
|
"--embedding-binding",
|
||||||
default="ollama",
|
default=get_env_value("EMBEDDING_BINDING", "ollama"),
|
||||||
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
|
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: from env or ollama)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse just these arguments first
|
# Parse temporary args for host defaults
|
||||||
temp_args, _ = parser.parse_known_args()
|
temp_args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
# Add remaining arguments with dynamic defaults for hosts
|
|
||||||
# Server configuration
|
# Server configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)"
|
"--host",
|
||||||
|
default=get_env_value("HOST", "0.0.0.0"),
|
||||||
|
help="Server host (default: from env or 0.0.0.0)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=9621, help="Server port (default: 9621)"
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=get_env_value("PORT", 9621, int),
|
||||||
|
help="Server port (default: from env or 9621)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Directory configuration
|
# Directory configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--working-dir",
|
"--working-dir",
|
||||||
default="./rag_storage",
|
default=get_env_value("WORKING_DIR", "./rag_storage"),
|
||||||
help="Working directory for RAG storage (default: ./rag_storage)",
|
help="Working directory for RAG storage (default: from env or ./rag_storage)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input-dir",
|
"--input-dir",
|
||||||
default="./inputs",
|
default=get_env_value("INPUT_DIR", "./inputs"),
|
||||||
help="Directory containing input documents (default: ./inputs)",
|
help="Directory containing input documents (default: from env or ./inputs)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# LLM Model configuration
|
# LLM Model configuration
|
||||||
default_llm_host = get_default_host(temp_args.llm_binding)
|
default_llm_host = get_env_value("LLM_BINDING_HOST", get_default_host(temp_args.llm_binding))
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--llm-binding-host",
|
"--llm-binding-host",
|
||||||
default=default_llm_host,
|
default=default_llm_host,
|
||||||
help=f"llm server host URL (default: {default_llm_host})",
|
help=f"llm server host URL (default: from env or {default_llm_host})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--llm-model",
|
"--llm-model",
|
||||||
default="mistral-nemo:latest",
|
default=get_env_value("LLM_MODEL", "mistral-nemo:latest"),
|
||||||
help="LLM model name (default: mistral-nemo:latest)",
|
help="LLM model name (default: from env or mistral-nemo:latest)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Embedding model configuration
|
# Embedding model configuration
|
||||||
default_embedding_host = get_default_host(temp_args.embedding_binding)
|
default_embedding_host = get_env_value("EMBEDDING_BINDING_HOST", get_default_host(temp_args.embedding_binding))
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding-binding-host",
|
"--embedding-binding-host",
|
||||||
default=default_embedding_host,
|
default=default_embedding_host,
|
||||||
help=f"embedding server host URL (default: {default_embedding_host})",
|
help=f"embedding server host URL (default: from env or {default_embedding_host})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding-model",
|
"--embedding-model",
|
||||||
default="bge-m3:latest",
|
default=get_env_value("EMBEDDING_MODEL", "bge-m3:latest"),
|
||||||
help="Embedding model name (default: bge-m3:latest)",
|
help="Embedding model name (default: from env or bge-m3:latest)",
|
||||||
)
|
)
|
||||||
|
|
||||||
def timeout_type(value):
|
def timeout_type(value):
|
||||||
@@ -113,62 +152,70 @@ def parse_args():
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--timeout",
|
"--timeout",
|
||||||
default=None,
|
default=get_env_value("TIMEOUT", None, timeout_type),
|
||||||
type=timeout_type,
|
type=timeout_type,
|
||||||
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
help="Timeout in seconds (useful when using slow AI). Use None for infinite timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
# RAG configuration
|
# RAG configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-async", type=int, default=4, help="Maximum async operations (default: 4)"
|
"--max-async",
|
||||||
|
type=int,
|
||||||
|
default=get_env_value("MAX_ASYNC", 4, int),
|
||||||
|
help="Maximum async operations (default: from env or 4)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
default=32768,
|
default=get_env_value("MAX_TOKENS", 32768, int),
|
||||||
help="Maximum token size (default: 32768)",
|
help="Maximum token size (default: from env or 32768)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding-dim",
|
"--embedding-dim",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=get_env_value("EMBEDDING_DIM", 1024, int),
|
||||||
help="Embedding dimensions (default: 1024)",
|
help="Embedding dimensions (default: from env or 1024)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-embed-tokens",
|
"--max-embed-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
default=8192,
|
default=get_env_value("MAX_EMBED_TOKENS", 8192, int),
|
||||||
help="Maximum embedding token size (default: 8192)",
|
help="Maximum embedding token size (default: from env or 8192)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Logging configuration
|
# Logging configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
default="INFO",
|
default=get_env_value("LOG_LEVEL", "INFO"),
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||||
help="Logging level (default: INFO)",
|
help="Logging level (default: from env or INFO)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--key",
|
"--key",
|
||||||
type=str,
|
type=str,
|
||||||
|
default=get_env_value("LIGHTRAG_API_KEY", None),
|
||||||
help="API key for authentication. This protects lightrag server against unauthorized access",
|
help="API key for authentication. This protects lightrag server against unauthorized access",
|
||||||
default=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional https parameters
|
# Optional https parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ssl", action="store_true", help="Enable HTTPS (default: False)"
|
"--ssl",
|
||||||
|
action="store_true",
|
||||||
|
default=get_env_value("SSL", False, bool),
|
||||||
|
help="Enable HTTPS (default: from env or False)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ssl-certfile",
|
"--ssl-certfile",
|
||||||
default=None,
|
default=get_env_value("SSL_CERTFILE", None),
|
||||||
help="Path to SSL certificate file (required if --ssl is enabled)",
|
help="Path to SSL certificate file (required if --ssl is enabled)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ssl-keyfile",
|
"--ssl-keyfile",
|
||||||
default=None,
|
default=get_env_value("SSL_KEYFILE", None),
|
||||||
help="Path to SSL private key file (required if --ssl is enabled)",
|
help="Path to SSL private key file (required if --ssl is enabled)",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -435,9 +482,11 @@ def create_app(args):
|
|||||||
else:
|
else:
|
||||||
logging.warning(f"No content extracted from file: {file_path}")
|
logging.warning(f"No content extracted from file: {file_path}")
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event():
|
@asynccontextmanager
|
||||||
"""Index all files in input directory during startup"""
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Lifespan context manager for startup and shutdown events"""
|
||||||
|
# Startup logic
|
||||||
try:
|
try:
|
||||||
new_files = doc_manager.scan_directory()
|
new_files = doc_manager.scan_directory()
|
||||||
for file_path in new_files:
|
for file_path in new_files:
|
||||||
@@ -448,7 +497,6 @@ def create_app(args):
|
|||||||
logging.error(f"Error indexing file {file_path}: {str(e)}")
|
logging.error(f"Error indexing file {file_path}: {str(e)}")
|
||||||
|
|
||||||
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
|
logging.info(f"Indexed {len(new_files)} documents from {args.input_dir}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during startup indexing: {str(e)}")
|
logging.error(f"Error during startup indexing: {str(e)}")
|
||||||
|
|
||||||
@@ -521,6 +569,7 @@ def create_app(args):
|
|||||||
else:
|
else:
|
||||||
return QueryResponse(response=response)
|
return QueryResponse(response=response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
||||||
|
@@ -16,3 +16,4 @@ torch
|
|||||||
tqdm
|
tqdm
|
||||||
transformers
|
transformers
|
||||||
uvicorn
|
uvicorn
|
||||||
|
python-dotenv
|
Reference in New Issue
Block a user