Centralize env reading

This commit is contained in:
yangdx
2025-04-04 21:06:21 +08:00
parent bd2c528dba
commit 95630aa669
4 changed files with 40 additions and 41 deletions

View File

@@ -5,7 +5,12 @@ Configs for the LightRAG API.
import os import os
import argparse import argparse
import logging import logging
from dotenv import load_dotenv
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
class OllamaServerInfos: class OllamaServerInfos:
# Constants for emulated Ollama model information # Constants for emulated Ollama model information
@@ -297,6 +302,11 @@ def parse_args() -> argparse.Namespace:
# Select Document loading tool (DOCLING, DEFAULT) # Select Document loading tool (DOCLING, DEFAULT)
args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT") args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
# Add environment variables that were previously read directly
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
args.summary_language = get_env_value("SUMMARY_LANGUAGE", "en")
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
# For JWT Auth # For JWT Auth
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "") args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")

View File

@@ -167,10 +167,10 @@ def create_app(args):
app = FastAPI(**app_kwargs) app = FastAPI(**app_kwargs)
def get_cors_origins(): def get_cors_origins():
"""Get allowed origins from environment variable """Get allowed origins from global_args
Returns a list of allowed origins, defaults to ["*"] if not set Returns a list of allowed origins, defaults to ["*"] if not set
""" """
origins_str = os.getenv("CORS_ORIGINS", "*") origins_str = global_args.cors_origins
if origins_str == "*": if origins_str == "*":
return ["*"] return ["*"]
return [origin.strip() for origin in origins_str.split(",")] return [origin.strip() for origin in origins_str.split(",")]
@@ -321,6 +321,9 @@ def create_app(args):
# namespace_prefix=args.namespace_prefix, # namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
addon_params={
"language": args.summary_language
},
) )
else: # azure_openai else: # azure_openai
rag = LightRAG( rag = LightRAG(
@@ -351,6 +354,9 @@ def create_app(args):
# namespace_prefix=args.namespace_prefix, # namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
addon_params={
"language": args.summary_language
},
) )
# Add routes # Add routes

View File

@@ -7,14 +7,9 @@ import os
import sys import sys
import signal import signal
import pipmaster as pm import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen, check_env_file from lightrag.api.utils_api import display_splash_screen, check_env_file
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
from dotenv import load_dotenv from .config import global_args
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
def check_and_install_dependencies(): def check_and_install_dependencies():
@@ -59,20 +54,17 @@ def main():
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information # Display startup information
display_splash_screen(args) display_splash_screen(global_args)
print("🚀 Starting LightRAG with Gunicorn") print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})") print(f"🔄 Worker management: Gunicorn (workers={global_args.workers})")
print("🔍 Preloading app: Enabled") print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization") print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80) print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION") print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}") print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}") print(f"Workers setting: {global_args.workers}")
print("=" * 80 + "\n") print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication # Import Gunicorn's StandaloneApplication
@@ -128,31 +120,31 @@ def main():
# Set configuration variables in gunicorn_config, prioritizing command line arguments # Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = ( gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1)) global_args.workers if global_args.workers else int(os.getenv("WORKERS", 1))
) )
# Bind configuration prioritizes command line arguments # Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") host = global_args.host if global_args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) port = global_args.port if global_args.port != 9621 else int(os.getenv("PORT", 9621))
gunicorn_config.bind = f"{host}:{port}" gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments # Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = ( gunicorn_config.loglevel = (
args.log_level.lower() global_args.log_level.lower()
if args.log_level if global_args.log_level
else os.getenv("LOG_LEVEL", "info") else os.getenv("LOG_LEVEL", "info")
) )
# Timeout configuration prioritizes command line arguments # Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = ( gunicorn_config.timeout = (
args.timeout if args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2)) global_args.timeout if global_args.timeout * 2 else int(os.getenv("TIMEOUT", 150 * 2))
) )
# Keepalive configuration # Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments # SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in ( if global_args.ssl or os.getenv("SSL", "").lower() in (
"true", "true",
"1", "1",
"yes", "yes",
@@ -160,12 +152,12 @@ def main():
"on", "on",
): ):
gunicorn_config.certfile = ( gunicorn_config.certfile = (
args.ssl_certfile global_args.ssl_certfile
if args.ssl_certfile if global_args.ssl_certfile
else os.getenv("SSL_CERTFILE") else os.getenv("SSL_CERTFILE")
) )
gunicorn_config.keyfile = ( gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") global_args.ssl_keyfile if global_args.ssl_keyfile else os.getenv("SSL_KEYFILE")
) )
# Set configuration options from the module # Set configuration options from the module
@@ -190,13 +182,13 @@ def main():
# Import the application # Import the application
from lightrag.api.lightrag_server import get_application from lightrag.api.lightrag_server import get_application
return get_application(args) return get_application(global_args)
# Create the application # Create the application
app = GunicornApp("") app = GunicornApp("")
# Force workers to be an integer and greater than 1 for multi-process mode # Force workers to be an integer and greater than 1 for multi-process mode
workers_count = int(args.workers) workers_count = int(global_args.workers)
if workers_count > 1: if workers_count > 1:
# Set a flag to indicate we're in the main process # Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"

View File

@@ -10,12 +10,10 @@ from ascii_colors import ASCIIColors
from lightrag.api import __api_version__ as api_version from lightrag.api import __api_version__ as api_version
from lightrag import __version__ as core_version from lightrag import __version__ as core_version
from fastapi import HTTPException, Security, Request, status from fastapi import HTTPException, Security, Request, status
from dotenv import load_dotenv
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler from .auth import auth_handler
from .config import ollama_server_infos from .config import ollama_server_infos, global_args
from ..prompt import PROMPTS
def check_env_file(): def check_env_file():
@@ -36,14 +34,8 @@ def check_env_file():
return True return True
# use the .env that is inside the current folder # Get whitelist paths from global_args, only once during initialization
# allows to use different .env file for each lightrag instance whitelist_paths = global_args.whitelist_paths.split(",")
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
# Get whitelist paths from environment variable, only once during initialization
default_whitelist = "/health,/api/*"
whitelist_paths = os.getenv("WHITELIST_PATHS", default_whitelist).split(",")
# Pre-compile path matching patterns # Pre-compile path matching patterns
whitelist_patterns: List[Tuple[str, bool]] = [] whitelist_patterns: List[Tuple[str, bool]] = []
@@ -195,7 +187,7 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.white(" ├─ Workers: ", end="") ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}") ASCIIColors.yellow(f"{args.workers}")
ASCIIColors.white(" ├─ CORS Origins: ", end="") ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") ASCIIColors.yellow(f"{args.cors_origins}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.white(" ├─ SSL Enabled: ", end="")
ASCIIColors.yellow(f"{args.ssl}") ASCIIColors.yellow(f"{args.ssl}")
if args.ssl: if args.ssl:
@@ -252,10 +244,9 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.embedding_dim}") ASCIIColors.yellow(f"{args.embedding_dim}")
# RAG Configuration # RAG Configuration
summary_language = os.getenv("SUMMARY_LANGUAGE", PROMPTS["DEFAULT_LANGUAGE"])
ASCIIColors.magenta("\n⚙️ RAG Configuration:") ASCIIColors.magenta("\n⚙️ RAG Configuration:")
ASCIIColors.white(" ├─ Summary Language: ", end="") ASCIIColors.white(" ├─ Summary Language: ", end="")
ASCIIColors.yellow(f"{summary_language}") ASCIIColors.yellow(f"{args.summary_language}")
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="") ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
ASCIIColors.yellow(f"{args.max_parallel_insert}") ASCIIColors.yellow(f"{args.max_parallel_insert}")
ASCIIColors.white(" ├─ Max Embed Tokens: ", end="") ASCIIColors.white(" ├─ Max Embed Tokens: ", end="")