diff --git a/.gitignore b/.gitignore
index 3eb55bd3..6deb14d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ site/
 # Logs / Reports
 *.log
+*.log.*
 *.logfire
 *.coverage/
 log/
diff --git a/.env.example b/env.example
similarity index 92%
rename from .env.example
rename to env.example
index e4034def..112676c6 100644
--- a/.env.example
+++ b/env.example
@@ -1,6 +1,9 @@
+### This is a sample .env file
+
 ### Server Configuration
 # HOST=0.0.0.0
 # PORT=9621
+# WORKERS=1
 # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080
@@ -22,6 +25,9 @@
 ### Logging level
 # LOG_LEVEL=INFO
 # VERBOSE=False
+# LOG_DIR=/path/to/log/directory # Log file directory path, defaults to the current working directory
+# LOG_MAX_BYTES=10485760 # Log file max size in bytes, defaults to 10MB
+# LOG_BACKUP_COUNT=5 # Number of backup files to keep, defaults to 5
 ### Max async calls for LLM
 # MAX_ASYNC=4
@@ -138,3 +144,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
 ### Qdrant
 QDRANT_URL=http://localhost:16333
 # QDRANT_API_KEY=your-api-key
+
+### Redis
+REDIS_URI=redis://localhost:6379
diff --git a/lightrag/api/README.md b/lightrag/api/README.md
index 86f18271..8f61f2f6 100644
--- a/lightrag/api/README.md
+++ b/lightrag/api/README.md
@@ -24,6 +24,8 @@ pip install -e ".[api]"
 ### Starting API Server with Default Settings
+After installing LightRAG with API support, you can start the API server with this command: `lightrag-server`.
+
 LightRAG requires both LLM and Embedding Model to work together to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends:
 * ollama
@@ -92,10 +94,43 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama
 LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai
 # start with ollama llm and ollama embedding (no apikey is needed)
-Light_server --llm-binding ollama --embedding-binding ollama
+lightrag-server --llm-binding ollama --embedding-binding ollama
 ```
+### Starting API Server with Gunicorn (Production)
+
+For production deployments, it is recommended to run LightRAG under Gunicorn (with Uvicorn workers) to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionality.
+
+```bash
+# Start with the lightrag-gunicorn command
+lightrag-gunicorn --workers 4
+
+# Alternatively, you can use the module directly
+python -m lightrag.api.run_with_gunicorn --workers 4
+```
+
+The `--workers` parameter is crucial for performance:
+
+- Determines how many worker processes Gunicorn will spawn to handle requests
+- Each worker can handle concurrent requests using asyncio
+- Recommended value is (2 x number_of_cores) + 1
+- For example, on a 4-core machine, use 9 workers: (2 x 4) + 1 = 9
+- Consider your server's memory when setting this value, as each worker consumes memory
+
+Other important startup parameters:
+
+- `--host`: Server listening address (default: 0.0.0.0)
+- `--port`: Server listening port (default: 9621)
+- `--timeout`: Request handling timeout (default: 150 seconds)
+- `--log-level`: Logging level (default: INFO)
+- `--ssl`: Enable HTTPS
+- `--ssl-certfile`: Path to SSL certificate file
+- `--ssl-keyfile`: Path to SSL private key file
+
+The command line parameters and environment variables accepted by `run_with_gunicorn.py` are exactly the same as those accepted by `lightrag-server`.
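+
+For example, on a 4-core machine you might start a production instance like this (a minimal sketch; the worker count, port and timeout values are illustrative and should be adapted to your deployment):
+
+```bash
+# (2 x 4 cores) + 1 = 9 workers
+lightrag-gunicorn --workers 9 --port 9621 --timeout 300
+
+# The same settings can also be supplied via .env instead of the command line:
+#   WORKERS=9
+#   PORT=9621
+#   TIMEOUT=300
+lightrag-gunicorn
+```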
+ ### For Azure OpenAI Backend + Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)): ```bash # Change the resource group name, location and OpenAI resource name as needed @@ -186,7 +221,7 @@ LightRAG supports binding to various LLM/Embedding backends: * openai & openai compatible * azure_openai -Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type. +Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select LLM backend type. ### Storage Types Supported diff --git a/lightrag/api/gunicorn_config.py b/lightrag/api/gunicorn_config.py new file mode 100644 index 00000000..7f9b4d58 --- /dev/null +++ b/lightrag/api/gunicorn_config.py @@ -0,0 +1,187 @@ +# gunicorn_config.py +import os +import logging +from lightrag.kg.shared_storage import finalize_share_data +from lightrag.api.lightrag_server import LightragPathFilter + +# Get log directory path from environment variable +log_dir = os.getenv("LOG_DIR", os.getcwd()) +log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + +# Get log file max size and backup count from environment variables +log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB +log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + +# These variables will be set by run_with_gunicorn.py +workers = None +bind = None +loglevel = None +certfile = None +keyfile = None + +# Enable preload_app option +preload_app = True + +# Use Uvicorn worker +worker_class = "uvicorn.workers.UvicornWorker" + +# Other Gunicorn configurations +timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py +keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s + +# Logging configuration +errorlog = os.getenv("ERROR_LOG", log_file_path) # Default write to lightrag.log +accesslog = os.getenv("ACCESS_LOG", log_file_path) # Default write to lightrag.log + +logconfig_dict = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "standard", + "stream": "ext://sys.stdout", + }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "standard", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf8", + }, + }, + "filters": { + "path_filter": { + "()": "lightrag.api.lightrag_server.LightragPathFilter", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, + }, + "gunicorn": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, + }, + "gunicorn.error": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, + }, + "gunicorn.access": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + }, +} + + 
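+# Gunicorn server hooks defined below (semantics as documented by Gunicorn):
+#   on_starting - runs once in the master process before any worker is forked
+#   on_exit     - runs in the master process when Gunicorn shuts down
+#   post_fork   - runs in each worker process right after it is forked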
+def on_starting(server): + """ + Executed when Gunicorn starts, before forking the first worker processes + You can use this function to do more initialization tasks for all processes + """ + print("=" * 80) + print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)") + print(f"Process ID: {os.getpid()}") + print("=" * 80) + + # Memory usage monitoring + try: + import psutil + + process = psutil.Process(os.getpid()) + memory_info = process.memory_info() + msg = ( + f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB" + ) + print(msg) + except ImportError: + print("psutil not installed, skipping memory usage reporting") + + print("Gunicorn initialization complete, forking workers...\n") + + +def on_exit(server): + """ + Executed when Gunicorn is shutting down. + This is a good place to release shared resources. + """ + print("=" * 80) + print("GUNICORN MASTER PROCESS: Shutting down") + print(f"Process ID: {os.getpid()}") + print("=" * 80) + + # Release shared resources + finalize_share_data() + + print("=" * 80) + print("Gunicorn shutdown complete") + print("=" * 80) + + +def post_fork(server, worker): + """ + Executed after a worker has been forked. + This is a good place to set up worker-specific configurations. + """ + # Configure formatters + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + simple_formatter = logging.Formatter("%(levelname)s: %(message)s") + + def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False): + """Set up a logger with console and file handlers""" + logger_instance = logging.getLogger(logger_name) + logger_instance.setLevel(level) + logger_instance.handlers = [] # Clear existing handlers + logger_instance.propagate = False + + # Add console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(simple_formatter) + console_handler.setLevel(level) + logger_instance.addHandler(console_handler) + + # Add file handler + file_handler = logging.handlers.RotatingFileHandler( + filename=log_file_path, + maxBytes=log_max_bytes, + backupCount=log_backup_count, + encoding="utf-8", + ) + file_handler.setFormatter(detailed_formatter) + file_handler.setLevel(level) + logger_instance.addHandler(file_handler) + + # Add path filter if requested + if add_filter: + path_filter = LightragPathFilter() + logger_instance.addFilter(path_filter) + + # Set up main loggers + log_level = loglevel.upper() if loglevel else "INFO" + setup_logger("uvicorn", log_level) + setup_logger("uvicorn.access", log_level, add_filter=True) + setup_logger("lightrag", log_level, add_filter=True) + + # Set up lightrag submodule loggers + for name in logging.root.manager.loggerDict: + if name.startswith("lightrag."): + setup_logger(name, log_level, add_filter=True) + + # Disable uvicorn.error logger + uvicorn_error_logger = logging.getLogger("uvicorn.error") + uvicorn_error_logger.handlers = [] + uvicorn_error_logger.setLevel(logging.CRITICAL) + uvicorn_error_logger.propagate = False diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 9b2a1c76..5f2c437f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -8,11 +8,12 @@ from fastapi import ( ) from fastapi.responses import FileResponse import asyncio -import threading import os -from fastapi.staticfiles import StaticFiles import logging -from typing import Dict +import logging.config +import uvicorn +import pipmaster as pm +from fastapi.staticfiles 
import StaticFiles from pathlib import Path import configparser from ascii_colors import ASCIIColors @@ -29,7 +30,6 @@ from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc -from lightrag.utils import logger from .routers.document_routes import ( DocumentManager, create_document_routes, @@ -39,33 +39,25 @@ from .routers.query_routes import create_query_routes from .routers.graph_routes import create_graph_routes from .routers.ollama_api import OllamaAPI +from lightrag.utils import logger, set_verbose_debug +from lightrag.kg.shared_storage import ( + get_namespace_data, + get_pipeline_status_lock, + initialize_pipeline_status, + get_all_update_flags_status, +) + # Load environment variables -try: - load_dotenv(override=True) -except Exception as e: - logger.warning(f"Failed to load .env file: {e}") +load_dotenv(override=True) # Initialize config parser config = configparser.ConfigParser() config.read("config.ini") -# Global configuration -global_top_k = 60 # default value -# Global progress tracker -scan_progress: Dict = { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, -} +class LightragPathFilter(logging.Filter): + """Filter for lightrag logger to filter out frequent path access logs""" -# Lock for thread-safe operations -progress_lock = threading.Lock() - - -class AccessLogFilter(logging.Filter): def __init__(self): super().__init__() # Define paths to be filtered @@ -73,17 +65,18 @@ class AccessLogFilter(logging.Filter): def filter(self, record): try: + # Check if record has the required attributes for an access log if not hasattr(record, "args") or not isinstance(record.args, tuple): return True if len(record.args) < 5: return True + # Extract method, path and status from the record args method = record.args[1] path = record.args[2] status = record.args[4] - # print(f"Debug - Method: {method}, Path: {path}, Status: {status}") - # print(f"Debug - Filtered paths: {self.filtered_paths}") + # Filter out successful GET requests to filtered paths if ( method == "GET" and (status == 200 or status == 304) @@ -92,19 +85,14 @@ class AccessLogFilter(logging.Filter): return False return True - except Exception: + # In case of any error, let the message through return True def create_app(args): - # Set global top_k - global global_top_k - global_top_k = args.top_k # save top_k from args - - # Initialize verbose debug setting - from lightrag.utils import set_verbose_debug - + # Setup logging + logger.setLevel(args.log_level) set_verbose_debug(args.verbose) # Verify that bindings are correctly setup @@ -138,11 +126,6 @@ def create_app(args): if not os.path.exists(args.ssl_keyfile): raise Exception(f"SSL key file not found: {args.ssl_keyfile}") - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key @@ -158,28 +141,23 @@ def create_app(args): try: # Initialize database connections await rag.initialize_storages() + await initialize_pipeline_status() # Auto scan documents if enabled if args.auto_scan_at_startup: - # Start scanning in background - with progress_lock: - if not scan_progress["is_scanning"]: - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - # Create background task - task = 
asyncio.create_task( - run_scanning_process(rag, doc_manager) - ) - app.state.background_tasks.add(task) - task.add_done_callback(app.state.background_tasks.discard) - ASCIIColors.info( - f"Started background scanning of documents from {args.input_dir}" - ) - else: - ASCIIColors.info( - "Skip document scanning(another scanning is active)" - ) + # Check if a task is already running (with lock protection) + pipeline_status = await get_namespace_data("pipeline_status") + should_start_task = False + async with get_pipeline_status_lock(): + if not pipeline_status.get("busy", False): + should_start_task = True + # Only start the task if no other task is running + if should_start_task: + # Create background task + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) + app.state.background_tasks.add(task) + task.add_done_callback(app.state.background_tasks.discard) + logger.info("Auto scan task started at startup.") ASCIIColors.green("\nServer is ready to accept connections! ๐Ÿš€\n") @@ -398,6 +376,9 @@ def create_app(args): @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" + # Get update flags status for all namespaces + update_status = await get_all_update_flags_status() + return { "status": "healthy", "working_directory": str(args.working_dir), @@ -417,6 +398,7 @@ def create_app(args): "graph_storage": args.graph_storage, "vector_storage": args.vector_storage, }, + "update_status": update_status, } # Webui mount webui/index.html @@ -435,12 +417,30 @@ def create_app(args): return app -def main(): - args = parse_args() - import uvicorn - import logging.config +def get_application(args=None): + """Factory function for creating the FastAPI application""" + if args is None: + args = parse_args() + return create_app(args) + + +def configure_logging(): + """Configure logging for uvicorn startup""" + + # Reset any existing handlers to ensure clean configuration + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger = logging.getLogger(logger_name) + logger.handlers = [] + logger.filters = [] + + # Get log directory path from environment variable + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups - # Configure uvicorn logging logging.config.dictConfig( { "version": 1, @@ -449,36 +449,106 @@ def main(): "default": { "format": "%(levelname)s: %(message)s", }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, }, "handlers": { - "default": { + "console": { "formatter": "default", "class": "logging.StreamHandler", "stream": "ext://sys.stderr", }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, }, "loggers": { - "uvicorn.access": { - "handlers": ["default"], + # Configure all uvicorn related loggers + "uvicorn": { + "handlers": ["console", "file"], "level": "INFO", "propagate": False, }, + "uvicorn.access": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + "uvicorn.error": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": 
False, + }, + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + }, + "filters": { + "path_filter": { + "()": "lightrag.api.lightrag_server.LightragPathFilter", + }, }, } ) - # Add filter to uvicorn access logger - uvicorn_access_logger = logging.getLogger("uvicorn.access") - uvicorn_access_logger.addFilter(AccessLogFilter()) - app = create_app(args) +def check_and_install_dependencies(): + """Check and install required dependencies""" + required_packages = [ + "uvicorn", + "tiktoken", + "fastapi", + # Add other required packages here + ] + + for package in required_packages: + if not pm.is_installed(package): + print(f"Installing {package}...") + pm.install(package) + print(f"{package} installed successfully") + + +def main(): + # Check if running under Gunicorn + if "GUNICORN_CMD_ARGS" in os.environ: + # If started with Gunicorn, return directly as Gunicorn will call get_application + print("Running under Gunicorn - worker management handled by Gunicorn") + return + + # Check and install dependencies + check_and_install_dependencies() + + from multiprocessing import freeze_support + + freeze_support() + + # Configure logging before parsing args + configure_logging() + + args = parse_args(is_uvicorn_mode=True) display_splash_screen(args) + + # Create application instance directly instead of using factory function + app = create_app(args) + + # Start Uvicorn in single process mode uvicorn_config = { - "app": app, + "app": app, # Pass application instance directly instead of string path "host": args.host, "port": args.port, "log_config": None, # Disable default config } + if args.ssl: uvicorn_config.update( { @@ -486,6 +556,8 @@ def main(): "ssl_keyfile": args.ssl_keyfile, } ) + + print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}") uvicorn.run(**uvicorn_config) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 5c742f39..ab5aff96 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API. 
""" import asyncio -import logging -import os +from lightrag.utils import logger import aiofiles import shutil import traceback @@ -12,7 +11,6 @@ import pipmaster as pm from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Any - from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from pydantic import BaseModel, Field, field_validator @@ -23,18 +21,6 @@ from ..utils_api import get_api_key_dependency router = APIRouter(prefix="/documents", tags=["documents"]) -# Global progress tracker -scan_progress: Dict = { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, -} - -# Lock for thread-safe operations -progress_lock = asyncio.Lock() - # Temporary file prefix temp_prefix = "__tmp__" @@ -161,19 +147,12 @@ class DocumentManager: """Scan input directory for new files""" new_files = [] for ext in self.supported_extensions: - logging.debug(f"Scanning for {ext} files in {self.input_dir}") + logger.debug(f"Scanning for {ext} files in {self.input_dir}") for file_path in self.input_dir.rglob(f"*{ext}"): if file_path not in self.indexed_files: new_files.append(file_path) return new_files - # def scan_directory(self) -> List[Path]: - # new_files = [] - # for ext in self.supported_extensions: - # for file_path in self.input_dir.rglob(f"*{ext}"): - # new_files.append(file_path) - # return new_files - def mark_as_indexed(self, file_path: Path): self.indexed_files.add(file_path) @@ -287,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ) content += "\n" case _: - logging.error( + logger.error( f"Unsupported file type: {file_path.name} (extension {ext})" ) return False @@ -295,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: # Insert into the RAG queue if content: await rag.apipeline_enqueue_documents(content) - logging.info(f"Successfully fetched and enqueued file: {file_path.name}") + logger.info(f"Successfully fetched and enqueued file: {file_path.name}") return True else: - logging.error(f"No content could be extracted from file: {file_path.name}") + logger.error(f"No content could be extracted from file: {file_path.name}") except Exception as e: - logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") + logger.error(traceback.format_exc()) finally: if file_path.name.startswith(temp_prefix): try: file_path.unlink() except Exception as e: - logging.error(f"Error deleting file {file_path}: {str(e)}") + logger.error(f"Error deleting file {file_path}: {str(e)}") return False @@ -324,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path): await rag.apipeline_process_enqueue_documents() except Exception as e: - logging.error(f"Error indexing file {file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error indexing file {file_path.name}: {str(e)}") + logger.error(traceback.format_exc()) async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): @@ -349,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): if enqueued: await rag.apipeline_process_enqueue_documents() except Exception as e: - logging.error(f"Error indexing files: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error indexing files: {str(e)}") + logger.error(traceback.format_exc()) async def 
pipeline_index_texts(rag: LightRAG, texts: List[str]): @@ -393,30 +372,17 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" try: new_files = doc_manager.scan_directory_for_new_files() - scan_progress["total_files"] = len(new_files) + total_files = len(new_files) + logger.info(f"Found {total_files} new files to index.") - logging.info(f"Found {len(new_files)} new files to index.") - for file_path in new_files: + for idx, file_path in enumerate(new_files): try: - async with progress_lock: - scan_progress["current_file"] = os.path.basename(file_path) - await pipeline_index_file(rag, file_path) - - async with progress_lock: - scan_progress["indexed_count"] += 1 - scan_progress["progress"] = ( - scan_progress["indexed_count"] / scan_progress["total_files"] - ) * 100 - except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") + logger.error(f"Error indexing file {file_path}: {str(e)}") except Exception as e: - logging.error(f"Error during scanning process: {str(e)}") - finally: - async with progress_lock: - scan_progress["is_scanning"] = False + logger.error(f"Error during scanning process: {str(e)}") def create_document_routes( @@ -436,34 +402,10 @@ def create_document_routes( Returns: dict: A dictionary containing the scanning status """ - async with progress_lock: - if scan_progress["is_scanning"]: - return {"status": "already_scanning"} - - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - # Start the scanning process in the background background_tasks.add_task(run_scanning_process, rag, doc_manager) return {"status": "scanning_started"} - @router.get("/scan-progress") - async def get_scan_progress(): - """ - Get the current progress of the document scanning process. - - Returns: - dict: A dictionary containing the current scanning progress information including: - - is_scanning: Whether a scan is currently in progress - - current_file: The file currently being processed - - indexed_count: Number of files indexed so far - - total_files: Total number of files to process - - progress: Percentage of completion - """ - async with progress_lock: - return scan_progress - @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( background_tasks: BackgroundTasks, file: UploadFile = File(...) @@ -504,8 +446,8 @@ def create_document_routes( message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/upload: {file.filename}: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -537,8 +479,8 @@ def create_document_routes( message="Text successfully received. Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/text: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -572,8 +514,8 @@ def create_document_routes( message="Text successfully received. 
Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/text: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -615,8 +557,8 @@ def create_document_routes( message=f"File '{file.filename}' saved successfully. Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/file: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/file: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -678,8 +620,8 @@ def create_document_routes( return InsertResponse(status=status, message=status_message) except Exception as e: - logging.error(f"Error /documents/batch: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/batch: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.delete( @@ -706,8 +648,42 @@ def create_document_routes( status="success", message="All documents cleared successfully" ) except Exception as e: - logging.error(f"Error DELETE /documents: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error DELETE /documents: {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/pipeline_status", dependencies=[Depends(optional_api_key)]) + async def get_pipeline_status(): + """ + Get the current status of the document indexing pipeline. + + This endpoint returns information about the current state of the document processing pipeline, + including whether it's busy, the current job name, when it started, how many documents + are being processed, how many batches there are, and which batch is currently being processed. 
+ + Returns: + dict: A dictionary containing the pipeline status information + """ + try: + from lightrag.kg.shared_storage import get_namespace_data + + pipeline_status = await get_namespace_data("pipeline_status") + + # Convert to regular dict if it's a Manager.dict + status_dict = dict(pipeline_status) + + # Convert history_messages to a regular list if it's a Manager.list + if "history_messages" in status_dict: + status_dict["history_messages"] = list(status_dict["history_messages"]) + + # Format the job_start time if it exists + if status_dict.get("job_start"): + status_dict["job_start"] = str(status_dict["job_start"]) + + return status_dict + except Exception as e: + logger.error(f"Error getting pipeline status: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.get("", dependencies=[Depends(optional_api_key)]) @@ -763,8 +739,8 @@ def create_document_routes( ) return response except Exception as e: - logging.error(f"Error GET /documents: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error GET /documents: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) return router diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py new file mode 100644 index 00000000..903c5c17 --- /dev/null +++ b/lightrag/api/run_with_gunicorn.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +""" +Start LightRAG server with Gunicorn +""" + +import os +import sys +import signal +import pipmaster as pm +from lightrag.api.utils_api import parse_args, display_splash_screen +from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data + + +def check_and_install_dependencies(): + """Check and install required dependencies""" + required_packages = [ + "gunicorn", + "tiktoken", + "psutil", + # Add other required packages here + ] + + for package in required_packages: + if not pm.is_installed(package): + print(f"Installing {package}...") + pm.install(package) + print(f"{package} installed successfully") + + +# Signal handler for graceful shutdown +def signal_handler(sig, frame): + print("\n\n" + "=" * 80) + print("RECEIVED TERMINATION SIGNAL") + print(f"Process ID: {os.getpid()}") + print("=" * 80 + "\n") + + # Release shared resources + finalize_share_data() + + # Exit with success status + sys.exit(0) + + +def main(): + # Check and install dependencies + check_and_install_dependencies() + + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # kill command + + # Parse all arguments using parse_args + args = parse_args(is_uvicorn_mode=False) + + # Display startup information + display_splash_screen(args) + + print("๐Ÿš€ Starting LightRAG with Gunicorn") + print(f"๐Ÿ”„ Worker management: Gunicorn (workers={args.workers})") + print("๐Ÿ” Preloading app: Enabled") + print("๐Ÿ“ Note: Using Gunicorn's preload feature for shared data initialization") + print("\n\n" + "=" * 80) + print("MAIN PROCESS INITIALIZATION") + print(f"Process ID: {os.getpid()}") + print(f"Workers setting: {args.workers}") + print("=" * 80 + "\n") + + # Import Gunicorn's StandaloneApplication + from gunicorn.app.base import BaseApplication + + # Define a custom application class that loads our config + class GunicornApp(BaseApplication): + def __init__(self, app, options=None): + self.options = options or {} + self.application = app + super().__init__() + + def 
load_config(self): + # Define valid Gunicorn configuration options + valid_options = { + "bind", + "workers", + "worker_class", + "timeout", + "keepalive", + "preload_app", + "errorlog", + "accesslog", + "loglevel", + "certfile", + "keyfile", + "limit_request_line", + "limit_request_fields", + "limit_request_field_size", + "graceful_timeout", + "max_requests", + "max_requests_jitter", + } + + # Special hooks that need to be set separately + special_hooks = { + "on_starting", + "on_reload", + "on_exit", + "pre_fork", + "post_fork", + "pre_exec", + "pre_request", + "post_request", + "worker_init", + "worker_exit", + "nworkers_changed", + "child_exit", + } + + # Import and configure the gunicorn_config module + from lightrag.api import gunicorn_config + + # Set configuration variables in gunicorn_config, prioritizing command line arguments + gunicorn_config.workers = ( + args.workers if args.workers else int(os.getenv("WORKERS", 1)) + ) + + # Bind configuration prioritizes command line arguments + host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") + port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) + gunicorn_config.bind = f"{host}:{port}" + + # Log level configuration prioritizes command line arguments + gunicorn_config.loglevel = ( + args.log_level.lower() + if args.log_level + else os.getenv("LOG_LEVEL", "info") + ) + + # Timeout configuration prioritizes command line arguments + gunicorn_config.timeout = ( + args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + ) + + # Keepalive configuration + gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) + + # SSL configuration prioritizes command line arguments + if args.ssl or os.getenv("SSL", "").lower() in ( + "true", + "1", + "yes", + "t", + "on", + ): + gunicorn_config.certfile = ( + args.ssl_certfile + if args.ssl_certfile + else os.getenv("SSL_CERTFILE") + ) + gunicorn_config.keyfile = ( + args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + ) + + # Set configuration options from the module + for key in dir(gunicorn_config): + if key in valid_options: + value = getattr(gunicorn_config, key) + # Skip functions like on_starting and None values + if not callable(value) and value is not None: + self.cfg.set(key, value) + # Set special hooks + elif key in special_hooks: + value = getattr(gunicorn_config, key) + if callable(value): + self.cfg.set(key, value) + + if hasattr(gunicorn_config, "logconfig_dict"): + self.cfg.set( + "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") + ) + + def load(self): + # Import the application + from lightrag.api.lightrag_server import get_application + + return get_application(args) + + # Create the application + app = GunicornApp("") + + # Force workers to be an integer and greater than 1 for multi-process mode + workers_count = int(args.workers) + if workers_count > 1: + # Set a flag to indicate we're in the main process + os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" + initialize_share_data(workers_count) + else: + initialize_share_data(1) + + # Run the application + print("\nStarting Gunicorn with direct Python API...") + app.run() + + +if __name__ == "__main__": + main() diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 17f19627..ed1250d4 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -6,6 +6,7 @@ import os import argparse from typing import Optional import sys +import logging from ascii_colors import ASCIIColors from lightrag.api import __api_version__ from fastapi 
import HTTPException, Security @@ -110,10 +111,13 @@ def get_env_value(env_key: str, default: any, value_type: type = str) -> any: return default -def parse_args() -> argparse.Namespace: +def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: """ Parse command line arguments with environment variable fallback + Args: + is_uvicorn_mode: Whether running under uvicorn mode + Returns: argparse.Namespace: Parsed arguments """ @@ -260,6 +264,14 @@ def parse_args() -> argparse.Namespace: help="Enable automatic scanning when the program starts", ) + # Server workers configuration + parser.add_argument( + "--workers", + type=int, + default=get_env_value("WORKERS", 1, int), + help="Number of worker processes (default: from env or 1)", + ) + # LLM and embedding bindings parser.add_argument( "--llm-binding", @@ -278,6 +290,15 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() + # If in uvicorn mode and workers > 1, force it to 1 and log warning + if is_uvicorn_mode and args.workers > 1: + original_workers = args.workers + args.workers = 1 + # Log warning directly here + logging.warning( + f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1" + ) + # convert relative path to absolute path args.working_dir = os.path.abspath(args.working_dir) args.input_dir = os.path.abspath(args.input_dir) @@ -346,17 +367,27 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.host}") ASCIIColors.white(" โ”œโ”€ Port: ", end="") ASCIIColors.yellow(f"{args.port}") + ASCIIColors.white(" โ”œโ”€ Workers: ", end="") + ASCIIColors.yellow(f"{args.workers}") ASCIIColors.white(" โ”œโ”€ CORS Origins: ", end="") ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") ASCIIColors.white(" โ”œโ”€ SSL Enabled: ", end="") ASCIIColors.yellow(f"{args.ssl}") - ASCIIColors.white(" โ””โ”€ API Key: ", end="") - ASCIIColors.yellow("Set" if args.key else "Not Set") if args.ssl: ASCIIColors.white(" โ”œโ”€ SSL Cert: ", end="") ASCIIColors.yellow(f"{args.ssl_certfile}") - ASCIIColors.white(" โ””โ”€ SSL Key: ", end="") + ASCIIColors.white(" โ”œโ”€ SSL Key: ", end="") ASCIIColors.yellow(f"{args.ssl_keyfile}") + ASCIIColors.white(" โ”œโ”€ Ollama Emulating Model: ", end="") + ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") + ASCIIColors.white(" โ”œโ”€ Log Level: ", end="") + ASCIIColors.yellow(f"{args.log_level}") + ASCIIColors.white(" โ”œโ”€ Verbose Debug: ", end="") + ASCIIColors.yellow(f"{args.verbose}") + ASCIIColors.white(" โ”œโ”€ Timeout: ", end="") + ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") + ASCIIColors.white(" โ””โ”€ API Key: ", end="") + ASCIIColors.yellow("Set" if args.key else "Not Set") # Directory Configuration ASCIIColors.magenta("\n๐Ÿ“‚ Directory Configuration:") @@ -415,16 +446,6 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.white(" โ””โ”€ Document Status Storage: ", end="") ASCIIColors.yellow(f"{args.doc_status_storage}") - ASCIIColors.magenta("\n๐Ÿ› ๏ธ System Configuration:") - ASCIIColors.white(" โ”œโ”€ Ollama Emulating Model: ", end="") - ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") - ASCIIColors.white(" โ”œโ”€ Log Level: ", end="") - ASCIIColors.yellow(f"{args.log_level}") - ASCIIColors.white(" โ”œโ”€ Verbose Debug: ", end="") - ASCIIColors.yellow(f"{args.verbose}") - ASCIIColors.white(" โ””โ”€ Timeout: ", end="") - ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") - # Server Status ASCIIColors.green("\nโœจ 
Server starting up...\n") @@ -478,7 +499,6 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.cyan(""" 3. Basic Operations: - POST /upload_document: Upload new documents to RAG - POST /query: Query your document collection - - GET /collections: List available collections 4. Monitor the server: - Check server logs for detailed operation information diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 2ac0899e..940ba73d 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -2,25 +2,25 @@ import os import time import asyncio from typing import Any, final - import json import numpy as np from dataclasses import dataclass import pipmaster as pm -from lightrag.utils import ( - logger, - compute_mdhash_id, -) -from lightrag.base import ( - BaseVectorStorage, -) +from lightrag.utils import logger, compute_mdhash_id +from lightrag.base import BaseVectorStorage if not pm.is_installed("faiss"): pm.install("faiss") -import faiss +import faiss # type: ignore +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @@ -55,14 +55,40 @@ class FaissVectorDBStorage(BaseVectorStorage): # If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP. self._index = faiss.IndexFlatIP(self._dim) - # Keep a local store for metadata, IDs, etc. # Maps โ†’ metadata (including your original ID). self._id_to_meta = {} - # Attempt to load an existing index + metadata from disk self._load_faiss_index() + async def initialize(self): + """Initialize storage data""" + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() + + async def _get_index(self): + """Check if the shtorage should be reloaded""" + # Acquire lock to prevent concurrent read and write + async with self._storage_lock: + # Check if storage was updated by another process + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process" + ) + # Reload data + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + return self._index + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. 
@@ -113,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage): ) return [] - # Normalize embeddings for cosine similarity (in-place) + # Convert to float32 and normalize embeddings for cosine similarity (in-place) + embeddings = embeddings.astype(np.float32) faiss.normalize_L2(embeddings) # Upsert logic: @@ -127,18 +154,19 @@ class FaissVectorDBStorage(BaseVectorStorage): existing_ids_to_remove.append(faiss_internal_id) if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + await self._remove_faiss_ids(existing_ids_to_remove) # Step 2: Add new vectors - start_idx = self._index.ntotal - self._index.add(embeddings) + index = await self._get_index() + start_idx = index.ntotal + index.add(embeddings) # Step 3: Store metadata + vector for each new ID for i, meta in enumerate(list_data): fid = start_idx + i # Store the raw vector so we can rebuild if something is removed meta["__vector__"] = embeddings[i].tolist() - self._id_to_meta[fid] = meta + self._id_to_meta.update({fid: meta}) logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") return [m["__id__"] for m in list_data] @@ -157,7 +185,8 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - distances, indices = self._index.search(embedding, top_k) + index = await self._get_index() + distances, indices = index.search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -201,8 +230,8 @@ class FaissVectorDBStorage(BaseVectorStorage): to_remove.append(fid) if to_remove: - self._remove_faiss_ids(to_remove) - logger.info( + await self._remove_faiss_ids(to_remove) + logger.debug( f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" ) @@ -223,12 +252,9 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Found {len(relations)} relations for {entity_name}") if relations: - self._remove_faiss_ids(relations) + await self._remove_faiss_ids(relations) logger.debug(f"Deleted {len(relations)} relations for {entity_name}") - async def index_done_callback(self) -> None: - self._save_faiss_index() - # -------------------------------------------------------------------------------- # Internal helper methods # -------------------------------------------------------------------------------- @@ -242,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage): return fid return None - def _remove_faiss_ids(self, fid_list): + async def _remove_faiss_ids(self, fid_list): """ Remove a list of internal Faiss IDs from the index. 
Because IndexFlatIP doesn't support 'removals', @@ -258,13 +284,14 @@ class FaissVectorDBStorage(BaseVectorStorage): vectors_to_keep.append(vec_meta["__vector__"]) # stored as list new_id_to_meta[new_fid] = vec_meta - # Re-init index - self._index = faiss.IndexFlatIP(self._dim) - if vectors_to_keep: - arr = np.array(vectors_to_keep, dtype=np.float32) - self._index.add(arr) + async with self._storage_lock: + # Re-init index + self._index = faiss.IndexFlatIP(self._dim) + if vectors_to_keep: + arr = np.array(vectors_to_keep, dtype=np.float32) + self._index.add(arr) - self._id_to_meta = new_id_to_meta + self._id_to_meta = new_id_to_meta def _save_faiss_index(self): """ @@ -312,3 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.warning("Starting with an empty Faiss index.") self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} + + async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for FAISS {self.namespace} was updated by another process, reloading..." + ) + async with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") + return False # Return error + + return True # Return success diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 63a295cd..01c657fa 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -12,6 +12,11 @@ from lightrag.utils import ( logger, write_json, ) +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + try_initialize_namespace, +) @final @@ -22,26 +27,42 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data: dict[str, Any] = load_json(self._file_name) or {} - logger.info(f"Loaded document status storage with {len(self._data)} records") + self._storage_lock = get_storage_lock() + self._data = None + + async def initialize(self): + """Initialize storage data""" + # check need_init must before get_namespace_data + need_init = try_initialize_namespace(self.namespace) + self._data = await get_namespace_data(self.namespace) + if need_init: + loaded_data = load_json(self._file_name) or {} + async with self._storage_lock: + self._data.update(loaded_data) + logger.info( + f"Loaded document status storage with {len(loaded_data)} records" + ) async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - return set(keys) - set(self._data.keys()) + async with self._storage_lock: + return set(keys) - set(self._data.keys()) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] - 
for id in ids: - data = self._data.get(id, None) - if data: - result.append(data) + async with self._storage_lock: + for id in ids: + data = self._data.get(id, None) + if data: + result.append(data) return result async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" counts = {status.value: 0 for status in DocStatus} - for doc in self._data.values(): - counts[doc["status"]] += 1 + async with self._storage_lock: + for doc in self._data.values(): + counts[doc["status"]] += 1 return counts async def get_docs_by_status( @@ -49,39 +70,48 @@ class JsonDocStatusStorage(DocStatusStorage): ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" result = {} - for k, v in self._data.items(): - if v["status"] == status.value: - try: - # Make a copy of the data to avoid modifying the original - data = v.copy() - # If content is missing, use content_summary as content - if "content" not in data and "content_summary" in data: - data["content"] = data["content_summary"] - result[k] = DocProcessingStatus(**data) - except KeyError as e: - logger.error(f"Missing required field for document {k}: {e}") - continue + async with self._storage_lock: + for k, v in self._data.items(): + if v["status"] == status.value: + try: + # Make a copy of the data to avoid modifying the original + data = v.copy() + # If content is missing, use content_summary as content + if "content" not in data and "content_summary" in data: + data["content"] = data["content_summary"] + result[k] = DocProcessingStatus(**data) + except KeyError as e: + logger.error(f"Missing required field for document {k}: {e}") + continue return result async def index_done_callback(self) -> None: - write_json(self._data, self._file_name) + async with self._storage_lock: + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) + write_json(data_dict, self._file_name) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - self._data.update(data) + async with self._storage_lock: + self._data.update(data) await self.index_done_callback() async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.get(id) + async with self._storage_lock: + return self._data.get(id) async def delete(self, doc_ids: list[str]): - for doc_id in doc_ids: - self._data.pop(doc_id, None) + async with self._storage_lock: + for doc_id in doc_ids: + self._data.pop(doc_id, None) await self.index_done_callback() async def drop(self) -> None: """Drop the storage""" - self._data.clear() + async with self._storage_lock: + self._data.clear() diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index e1ea507a..8d707899 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,4 +1,3 @@ -import asyncio import os from dataclasses import dataclass from typing import Any, final @@ -11,6 +10,11 @@ from lightrag.utils import ( logger, write_json, ) +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + try_initialize_namespace, +) @final @@ -19,37 +23,56 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data: dict[str, Any] = load_json(self._file_name) or {} - self._lock = asyncio.Lock() - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + self._storage_lock = 
get_storage_lock() + self._data = None + + async def initialize(self): + """Initialize storage data""" + # check need_init must before get_namespace_data + need_init = try_initialize_namespace(self.namespace) + self._data = await get_namespace_data(self.namespace) + if need_init: + loaded_data = load_json(self._file_name) or {} + async with self._storage_lock: + self._data.update(loaded_data) + logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data") async def index_done_callback(self) -> None: - write_json(self._data, self._file_name) + async with self._storage_lock: + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) + write_json(data_dict, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: - return self._data.get(id) + async with self._storage_lock: + return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return [ - ( - {k: v for k, v in self._data[id].items()} - if self._data.get(id, None) - else None - ) - for id in ids - ] + async with self._storage_lock: + return [ + ( + {k: v for k, v in self._data[id].items()} + if self._data.get(id, None) + else None + ) + for id in ids + ] async def filter_keys(self, keys: set[str]) -> set[str]: - return set(keys) - set(self._data.keys()) + async with self._storage_lock: + return set(keys) - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) + async with self._storage_lock: + left_data = {k: v for k, v in data.items() if k not in self._data} + self._data.update(left_data) async def delete(self, ids: list[str]) -> None: - for doc_id in ids: - self._data.pop(doc_id, None) + async with self._storage_lock: + for doc_id in ids: + self._data.pop(doc_id, None) await self.index_done_callback() diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index b0900095..07c800de 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -3,7 +3,6 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np - import time from lightrag.utils import ( @@ -11,22 +10,29 @@ from lightrag.utils import ( compute_mdhash_id, ) import pipmaster as pm -from lightrag.base import ( - BaseVectorStorage, -) +from lightrag.base import BaseVectorStorage if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): - # Initialize lock only for file operations - self._save_lock = asyncio.Lock() + # Initialize basic attributes + self._client = None + self._storage_lock = None + self.storage_updated = None + # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -40,10 +46,43 @@ class NanoVectorDBStorage(BaseVectorStorage): self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name + 
self.embedding_func.embedding_dim, + storage_file=self._client_file_name, ) + async def initialize(self): + """Initialize storage data""" + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() + + async def _get_client(self): + """Check if the storage should be reloaded""" + # Acquire lock to prevent concurrent read and write + async with self._storage_lock: + # Check if data needs to be reloaded + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} reloading {self.namespace} due to update by another process" + ) + # Reload data + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + # Reset update flag + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + + return self._client + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: @@ -64,6 +103,7 @@ class NanoVectorDBStorage(BaseVectorStorage): for i in range(0, len(contents), self._max_batch_size) ] + # Execute embedding outside of lock to avoid long lock times embedding_tasks = [self.embedding_func(batch) for batch in batches] embeddings_list = await asyncio.gather(*embedding_tasks) @@ -71,7 +111,8 @@ class NanoVectorDBStorage(BaseVectorStorage): if len(embeddings) == len(list_data): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - results = self._client.upsert(datas=list_data) + client = await self._get_client() + results = client.upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. 
@@ -80,9 +121,12 @@ class NanoVectorDBStorage(BaseVectorStorage): ) async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + # Execute embedding outside of lock to avoid long lock times embedding = await self.embedding_func([query]) embedding = embedding[0] - results = self._client.query( + + client = await self._get_client() + results = client.query( query=embedding, top_k=top_k, better_than_threshold=self.cosine_better_than_threshold, @@ -99,8 +143,9 @@ class NanoVectorDBStorage(BaseVectorStorage): return results @property - def client_storage(self): - return getattr(self._client, "_NanoVectorDB__storage") + async def client_storage(self): + client = await self._get_client() + return getattr(client, "_NanoVectorDB__storage") async def delete(self, ids: list[str]): """Delete vectors with specified IDs @@ -109,8 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - self._client.delete(ids) - logger.info( + client = await self._get_client() + client.delete(ids) + logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) except Exception as e: @@ -122,9 +168,11 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.debug( f"Attempting to delete entity {entity_name} with ID {entity_id}" ) + # Check if the entity exists - if self._client.get([entity_id]): - await self.delete([entity_id]) + client = await self._get_client() + if client.get([entity_id]): + client.delete([entity_id]) logger.debug(f"Successfully deleted entity {entity_name}") else: logger.debug(f"Entity {entity_name} not found in storage") @@ -133,16 +181,19 @@ class NanoVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: try: + client = await self._get_client() + storage = getattr(client, "_NanoVectorDB__storage") relations = [ dp - for dp in self.client_storage["data"] + for dp in storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name ] logger.debug(f"Found {len(relations)} relations for entity {entity_name}") ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: - await self.delete(ids_to_delete) + client = await self._get_client() + client.delete(ids_to_delete) logger.debug( f"Deleted {len(ids_to_delete)} relations for {entity_name}" ) @@ -151,6 +202,37 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") - async def index_done_callback(self) -> None: - async with self._save_lock: - self._client.save() + async def index_done_callback(self) -> bool: + """Save data to disk""" + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for {self.namespace} was updated by another process, reloading..." 
+ ) + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + # Reset update flag + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._client.save() + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + return True # Return success + except Exception as e: + logger.error(f"Error saving data for {self.namespace}: {e}") + return False # Return error + + return True # Return success diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 1f5d34d0..f11e9c0e 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,18 +1,12 @@ import os from dataclasses import dataclass from typing import Any, final - import numpy as np - from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from lightrag.utils import ( - logger, -) +from lightrag.utils import logger +from lightrag.base import BaseGraphStorage -from lightrag.base import ( - BaseGraphStorage, -) import pipmaster as pm if not pm.is_installed("networkx"): @@ -23,6 +17,12 @@ if not pm.is_installed("graspologic"): import networkx as nx from graspologic import embed +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @@ -78,56 +78,101 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) + self._storage_lock = None + self.storage_updated = None + self._graph = None + + # Load initial graph preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) if preloaded_graph is not None: logger.info( f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) + else: + logger.info("Created new empty graph") self._graph = preloaded_graph or nx.Graph() + self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } - async def index_done_callback(self) -> None: - NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) + async def initialize(self): + """Initialize storage data""" + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() + + async def _get_graph(self): + """Check if the storage should be reloaded""" + # Acquire lock to prevent concurrent read and write + async with self._storage_lock: + # Check if data needs to be reloaded + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process" + ) + # Reload data + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) + # Reset update flag + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + + return self._graph async def has_node(self, node_id: str) -> bool: - return self._graph.has_node(node_id) + graph = await self._get_graph() + return graph.has_node(node_id) async def has_edge(self, 
source_node_id: str, target_node_id: str) -> bool: - return self._graph.has_edge(source_node_id, target_node_id) + graph = await self._get_graph() + return graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: - return self._graph.nodes.get(node_id) + graph = await self._get_graph() + return graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: - return self._graph.degree(node_id) + graph = await self._get_graph() + return graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._graph.degree(src_id) + self._graph.degree(tgt_id) + graph = await self._get_graph() + return graph.degree(src_id) + graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - return self._graph.edges.get((source_node_id, target_node_id)) + graph = await self._get_graph() + return graph.edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - if self._graph.has_node(source_node_id): - return list(self._graph.edges(source_node_id)) + graph = await self._get_graph() + if graph.has_node(source_node_id): + return list(graph.edges(source_node_id)) return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - self._graph.add_node(node_id, **node_data) + graph = await self._get_graph() + graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - self._graph.add_edge(source_node_id, target_node_id, **edge_data) + graph = await self._get_graph() + graph.add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: - if self._graph.has_node(node_id): - self._graph.remove_node(node_id) - logger.info(f"Node {node_id} deleted from the graph.") + graph = await self._get_graph() + if graph.has_node(node_id): + graph.remove_node(node_id) + logger.debug(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") @@ -138,35 +183,37 @@ class NetworkXStorage(BaseGraphStorage): raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() - # @TODO: NOT USED + # TODO: NOT USED async def _node2vec_embed(self): + graph = await self._get_graph() embeddings, nodes = embed.node2vec_embed( - self._graph, + graph, **self.global_config["node2vec_params"], ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids - def remove_nodes(self, nodes: list[str]): + async def remove_nodes(self, nodes: list[str]): """Delete multiple nodes Args: nodes: List of node IDs to be deleted """ + graph = await self._get_graph() for node in nodes: - if self._graph.has_node(node): - self._graph.remove_node(node) + if graph.has_node(node): + graph.remove_node(node) - def remove_edges(self, edges: list[tuple[str, str]]): + async def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ + graph = await self._get_graph() for source, target in edges: - if self._graph.has_edge(source, target): - self._graph.remove_edge(source, target) + if graph.has_edge(source, target): + graph.remove_edge(source, target) async def 
get_all_labels(self) -> list[str]: """ @@ -174,8 +221,9 @@ class NetworkXStorage(BaseGraphStorage): Returns: [label1, label2, ...] # Alphabetically sorted label list """ + graph = await self._get_graph() labels = set() - for node in self._graph.nodes(): + for node in graph.nodes(): labels.add(str(node)) # Add node id as a label # Return sorted list @@ -198,16 +246,18 @@ class NetworkXStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() + graph = await self._get_graph() + # Handle special case for "*" label if node_label == "*": # For "*", return the entire graph including all nodes and edges subgraph = ( - self._graph.copy() + graph.copy() ) # Create a copy to avoid modifying the original graph else: # Find nodes with matching node id (partial match) nodes_to_explore = [] - for n, attr in self._graph.nodes(data=True): + for n, attr in graph.nodes(data=True): if node_label in str(n): # Use partial matching nodes_to_explore.append(n) @@ -216,7 +266,7 @@ class NetworkXStorage(BaseGraphStorage): return result # Get subgraph using ego_graph - subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth) + subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) # Check if number of nodes exceeds max_graph_nodes max_graph_nodes = 500 @@ -278,9 +328,41 @@ class NetworkXStorage(BaseGraphStorage): ) seen_edges.add(edge_id) - # logger.info(result.edges) - logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result + + async def index_done_callback(self) -> bool: + """Save data to disk""" + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Graph for {self.namespace} was updated by another process, reloading..." 
+ ) + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) + # Reset update flag + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + return True # Return success + except Exception as e: + logger.error(f"Error saving graph for {self.namespace}: {e}") + return False # Return error + + return True diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index c91d23f0..51044be5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -38,8 +38,8 @@ import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") -import asyncpg -from asyncpg import Pool +import asyncpg # type: ignore +from asyncpg import Pool # type: ignore class PostgreSQLDB: diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py new file mode 100644 index 00000000..c8c154aa --- /dev/null +++ b/lightrag/kg/shared_storage.py @@ -0,0 +1,374 @@ +import os +import sys +import asyncio +from multiprocessing.synchronize import Lock as ProcessLock +from multiprocessing import Manager +from typing import Any, Dict, Optional, Union, TypeVar, Generic + + +# Define a direct print function for critical logs that must be visible in all processes +def direct_log(message, level="INFO"): + """ + Log a message directly to stderr to ensure visibility in all processes, + including the Gunicorn master process. 
+    """
+    print(f"{level}: {message}", file=sys.stderr, flush=True)
+
+
+T = TypeVar("T")
+LockType = Union[ProcessLock, asyncio.Lock]
+
+is_multiprocess = None
+_workers = None
+_manager = None
+_initialized = None
+
+# shared data for storage across processes
+_shared_dicts: Optional[Dict[str, Any]] = None
+_init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
+_update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated
+
+# locks for mutex access
+_storage_lock: Optional[LockType] = None
+_internal_lock: Optional[LockType] = None
+_pipeline_status_lock: Optional[LockType] = None
+
+
+class UnifiedLock(Generic[T]):
+    """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
+
+    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
+        self._lock = lock
+        self._is_async = is_async
+
+    async def __aenter__(self) -> "UnifiedLock[T]":
+        if self._is_async:
+            await self._lock.acquire()
+        else:
+            self._lock.acquire()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if self._is_async:
+            self._lock.release()
+        else:
+            self._lock.release()
+
+    def __enter__(self) -> "UnifiedLock[T]":
+        """For backward compatibility"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for shared_storage lock")
+        self._lock.acquire()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """For backward compatibility"""
+        if self._is_async:
+            raise RuntimeError("Use 'async with' for shared_storage lock")
+        self._lock.release()
+
+
+def get_internal_lock() -> UnifiedLock:
+    """Return the unified internal lock for data consistency"""
+    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
+
+
+def get_storage_lock() -> UnifiedLock:
+    """Return the unified storage lock for data consistency"""
+    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
+
+
+def get_pipeline_status_lock() -> UnifiedLock:
+    """Return the unified pipeline status lock for data consistency"""
+    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
+
+
+def initialize_share_data(workers: int = 1):
+    """
+    Initialize shared storage data for single or multi-process mode.
+
+    When used with Gunicorn's preload feature, this function is called once in the
+    master process before forking worker processes, allowing all workers to share
+    the same initialized data.
+
+    In single-process mode, this function is called in the FastAPI lifespan function.
+
+    The function determines whether to use cross-process shared variables for data storage
+    based on the number of workers. If workers=1, it uses thread locks and local dictionaries.
+    If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager.
+
+    Args:
+        workers (int): Number of worker processes. If 1, single-process mode is used.
+                       If > 1, multi-process mode with shared memory is used.
+    """
+    global \
+        _manager, \
+        _workers, \
+        is_multiprocess, \
+        _storage_lock, \
+        _internal_lock, \
+        _pipeline_status_lock, \
+        _shared_dicts, \
+        _init_flags, \
+        _initialized, \
+        _update_flags
+
+    # Check if already initialized
+    if _initialized:
+        direct_log(
+            f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
+        )
+        return
+
+    _manager = Manager()
+    _workers = workers
+
+    if workers > 1:
+        is_multiprocess = True
+        _internal_lock = _manager.Lock()
+        _storage_lock = _manager.Lock()
+        _pipeline_status_lock = _manager.Lock()
+        _shared_dicts = _manager.dict()
+        _init_flags = _manager.dict()
+        _update_flags = _manager.dict()
+        direct_log(
+            f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
+        )
+    else:
+        is_multiprocess = False
+        _internal_lock = asyncio.Lock()
+        _storage_lock = asyncio.Lock()
+        _pipeline_status_lock = asyncio.Lock()
+        _shared_dicts = {}
+        _init_flags = {}
+        _update_flags = {}
+        direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
+
+    # Mark as initialized
+    _initialized = True
+
+
+async def initialize_pipeline_status():
+    """
+    Initialize pipeline namespace with default values.
+    This function is called during the FastAPI lifespan for each worker.
+    """
+    pipeline_namespace = await get_namespace_data("pipeline_status")
+
+    async with get_internal_lock():
+        # Check if already initialized by checking for required fields
+        if "busy" in pipeline_namespace:
+            return
+
+        # Create a shared list object for history_messages
+        history_messages = _manager.list() if is_multiprocess else []
+        pipeline_namespace.update(
+            {
+                "busy": False,  # Control concurrent processes
+                "job_name": "Default Job",  # Current job name (indexing files/indexing texts)
+                "job_start": None,  # Job start time
+                "docs": 0,  # Total number of documents to be indexed
+                "batchs": 0,  # Number of batches for processing documents
+                "cur_batch": 0,  # Current processing batch
+                "request_pending": False,  # Flag for pending request for processing
+                "latest_message": "",  # Latest message from pipeline processing
+                "history_messages": history_messages,  # Use the shared list object
+            }
+        )
+        direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
+
+
+async def get_update_flag(namespace: str):
+    """
+    Create a namespace's update flag for a worker.
+    Return the update flag to the caller for referencing or resetting.
+    """
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            if is_multiprocess and _manager is not None:
+                _update_flags[namespace] = _manager.list()
+            else:
+                _update_flags[namespace] = []
+            direct_log(
+                f"Process {os.getpid()} initialized update flags for namespace: [{namespace}]"
+            )
+
+        if is_multiprocess and _manager is not None:
+            new_update_flag = _manager.Value("b", False)
+        else:
+            new_update_flag = False
+
+        _update_flags[namespace].append(new_update_flag)
+        return new_update_flag
+
+
+async def set_all_update_flags(namespace: str):
+    """Set all update flags of a namespace, indicating that every worker needs to reload data from files"""
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Update flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = True
+            else:
+                _update_flags[namespace][i] = True
+
+
+async def get_all_update_flags_status() -> Dict[str, list]:
+    """
+    Get update flags status for all namespaces.
+
+    Returns:
+        Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
+    """
+    if _update_flags is None:
+        return {}
+
+    result = {}
+    async with get_internal_lock():
+        for namespace, flags in _update_flags.items():
+            worker_statuses = []
+            for flag in flags:
+                if is_multiprocess:
+                    worker_statuses.append(flag.value)
+                else:
+                    worker_statuses.append(flag)
+            result[namespace] = worker_statuses
+
+    return result
+
+
+def try_initialize_namespace(namespace: str) -> bool:
+    """
+    Returns True if the current worker (process) gets initialization permission for loading data later.
+    A worker that does not get the permission must not load data from files.
+    """
+    global _init_flags, _manager
+
+    if _init_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    if namespace not in _init_flags:
+        _init_flags[namespace] = True
+        direct_log(
+            f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+        )
+        return True
+    direct_log(
+        f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+    )
+    return False
+
+
+async def get_namespace_data(namespace: str) -> Dict[str, Any]:
+    """Get the shared data reference for a specific namespace"""
+    if _shared_dicts is None:
+        direct_log(
+            f"Error: try to get namespace before it is initialized, pid={os.getpid()}",
+            level="ERROR",
+        )
+        raise ValueError("Shared dictionaries not initialized")
+
+    async with get_internal_lock():
+        if namespace not in _shared_dicts:
+            if is_multiprocess and _manager is not None:
+                _shared_dicts[namespace] = _manager.dict()
+            else:
+                _shared_dicts[namespace] = {}
+
+        return _shared_dicts[namespace]
+
+
+def finalize_share_data():
+    """
+    Release shared resources and clean up.
+
+    This function should be called when the application is shutting down
+    to properly release shared resources and avoid memory leaks.
+
+    In multi-process mode, it shuts down the Manager and releases all shared objects.
+    In single-process mode, it simply resets the global variables.
+ """ + global \ + _manager, \ + is_multiprocess, \ + _storage_lock, \ + _internal_lock, \ + _pipeline_status_lock, \ + _shared_dicts, \ + _init_flags, \ + _initialized, \ + _update_flags + + # Check if already initialized + if not _initialized: + direct_log( + f"Process {os.getpid()} storage data not initialized, nothing to finalize" + ) + return + + direct_log( + f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})" + ) + + # In multi-process mode, shut down the Manager + if is_multiprocess and _manager is not None: + try: + # Clear shared resources before shutting down Manager + if _shared_dicts is not None: + # Clear pipeline status history messages first if exists + try: + pipeline_status = _shared_dicts.get("pipeline_status", {}) + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].clear() + except Exception: + pass # Ignore any errors during history messages cleanup + _shared_dicts.clear() + if _init_flags is not None: + _init_flags.clear() + if _update_flags is not None: + # Clear each namespace's update flags list and Value objects + try: + for namespace in _update_flags: + flags_list = _update_flags[namespace] + if isinstance(flags_list, list): + # Clear Value objects in the list + for flag in flags_list: + if hasattr( + flag, "value" + ): # Check if it's a Value object + flag.value = False + flags_list.clear() + except Exception: + pass # Ignore any errors during update flags cleanup + _update_flags.clear() + + # Shut down the Manager - this will automatically clean up all shared resources + _manager.shutdown() + direct_log(f"Process {os.getpid()} Manager shutdown complete") + except Exception as e: + direct_log( + f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR" + ) + + # Reset global variables + _manager = None + _initialized = None + is_multiprocess = None + _shared_dicts = None + _init_flags = None + _storage_lock = None + _internal_lock = None + _pipeline_status_lock = None + _update_flags = None + + direct_log(f"Process {os.getpid()} storage data finalization complete") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 805da1a2..208bdf3e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -45,7 +45,6 @@ from .utils import ( lazy_external_import, limit_async_func_call, logger, - set_logger, ) from .types import KnowledgeGraph from dotenv import load_dotenv @@ -268,9 +267,14 @@ class LightRAG: def __post_init__(self): os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) - set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") + from lightrag.kg.shared_storage import ( + initialize_share_data, + ) + + initialize_share_data() + if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) @@ -692,117 +696,221 @@ class LightRAG: 3. Process each chunk for entity and relation extraction 4. Update the document status """ - # 1. Get all pending, failed, and abnormally terminated processing documents. 
-        # Run the asynchronous status retrievals in parallel using asyncio.gather
-        processing_docs, failed_docs, pending_docs = await asyncio.gather(
-            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
-            self.doc_status.get_docs_by_status(DocStatus.FAILED),
-            self.doc_status.get_docs_by_status(DocStatus.PENDING),
+        from lightrag.kg.shared_storage import (
+            get_namespace_data,
+            get_pipeline_status_lock,
         )
-        to_process_docs: dict[str, DocProcessingStatus] = {}
-        to_process_docs.update(processing_docs)
-        to_process_docs.update(failed_docs)
-        to_process_docs.update(pending_docs)
+        # Get pipeline status shared data and lock
+        pipeline_status = await get_namespace_data("pipeline_status")
+        pipeline_status_lock = get_pipeline_status_lock()
-        if not to_process_docs:
-            logger.info("All documents have been processed or are duplicates")
-            return
+        # Check if another process is already processing the queue
+        async with pipeline_status_lock:
+            # Ensure only one worker is processing documents
+            if not pipeline_status.get("busy", False):
+                # First check whether there are any documents to process
+                processing_docs, failed_docs, pending_docs = await asyncio.gather(
+                    self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
+                    self.doc_status.get_docs_by_status(DocStatus.FAILED),
+                    self.doc_status.get_docs_by_status(DocStatus.PENDING),
+                )
-        # 2. split docs into chunks, insert chunks, update doc status
-        docs_batches = [
-            list(to_process_docs.items())[i : i + self.max_parallel_insert]
-            for i in range(0, len(to_process_docs), self.max_parallel_insert)
-        ]
+                to_process_docs: dict[str, DocProcessingStatus] = {}
+                to_process_docs.update(processing_docs)
+                to_process_docs.update(failed_docs)
+                to_process_docs.update(pending_docs)
-        logger.info(f"Number of batches to process: {len(docs_batches)}.")
+                # If there are no documents to process, return immediately and leave pipeline_status unchanged
+                if not to_process_docs:
+                    logger.info("No documents to process")
+                    return
-        batches: list[Any] = []
-        # 3. iterate over batches
-        for batch_idx, docs_batch in enumerate(docs_batches):
-
-            async def batch(
-                batch_idx: int,
-                docs_batch: list[tuple[str, DocProcessingStatus]],
-                size_batch: int,
-            ) -> None:
-                logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
-                # 4. iterate over batch
-                for doc_id_processing_status in docs_batch:
-                    doc_id, status_doc = doc_id_processing_status
-                    # Generate chunks from document
-                    chunks: dict[str, Any] = {
-                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                            **dp,
-                            "full_doc_id": doc_id,
-                        }
-                        for dp in self.chunking_func(
-                            status_doc.content,
-                            split_by_character,
-                            split_by_character_only,
-                            self.chunk_overlap_token_size,
-                            self.chunk_token_size,
-                            self.tiktoken_model_name,
-                        )
+                # Documents need processing: update pipeline_status
+                pipeline_status.update(
+                    {
+                        "busy": True,
+                        "job_name": "indexing files",
+                        "job_start": datetime.now().isoformat(),
+                        "docs": 0,
+                        "batchs": 0,
+                        "cur_batch": 0,
+                        "request_pending": False,  # Clear any previous request
+                        "latest_message": "",
                     }
-                    # Process document (text chunks and full docs) in parallel
-                    tasks = [
-                        self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.PROCESSING,
-                                    "updated_at": datetime.now().isoformat(),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                }
-                            }
-                        ),
-                        self.chunks_vdb.upsert(chunks),
-                        self._process_entity_relation_graph(chunks),
-                        self.full_docs.upsert(
-                            {doc_id: {"content": status_doc.content}}
-                        ),
-                        self.text_chunks.upsert(chunks),
-                    ]
-                    try:
-                        await asyncio.gather(*tasks)
-                        await self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.PROCESSED,
-                                    "chunks_count": len(chunks),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                    "updated_at": datetime.now().isoformat(),
-                                }
-                            }
-                        )
-                    except Exception as e:
-                        logger.error(f"Failed to process document {doc_id}: {str(e)}")
-                        await self.doc_status.upsert(
-                            {
-                                doc_id: {
-                                    "status": DocStatus.FAILED,
-                                    "error": str(e),
-                                    "content": status_doc.content,
-                                    "content_summary": status_doc.content_summary,
-                                    "content_length": status_doc.content_length,
-                                    "created_at": status_doc.created_at,
-                                    "updated_at": datetime.now().isoformat(),
-                                }
-                            }
-                        )
-                        continue
-                logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
+                )
+                # Cleaning history_messages without breaking it as a shared list object
+                del pipeline_status["history_messages"][:]
+            else:
+                # Another process is busy, just set request flag and return
+                pipeline_status["request_pending"] = True
+                logger.info(
+                    "Another process is already processing the document queue. Request queued."
+                )
+                return
-            batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
+        try:
+            # Process documents until no more documents or requests
+            while True:
+                if not to_process_docs:
+                    log_message = "All documents have been processed or are duplicates"
+                    logger.info(log_message)
+                    pipeline_status["latest_message"] = log_message
+                    pipeline_status["history_messages"].append(log_message)
+                    break
-        await asyncio.gather(*batches)
-        await self._insert_done()
+                # 2. split docs into chunks, insert chunks, update doc status
+                docs_batches = [
+                    list(to_process_docs.items())[i : i + self.max_parallel_insert]
+                    for i in range(0, len(to_process_docs), self.max_parallel_insert)
+                ]
+
+                log_message = f"Number of batches to process: {len(docs_batches)}."
+ logger.info(log_message) + + # Update pipeline status with current batch information + pipeline_status["docs"] += len(to_process_docs) + pipeline_status["batchs"] += len(docs_batches) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + batches: list[Any] = [] + # 3. iterate over batches + for batch_idx, docs_batch in enumerate(docs_batches): + # Update current batch in pipeline status (directly, as it's atomic) + pipeline_status["cur_batch"] += 1 + + async def batch( + batch_idx: int, + docs_batch: list[tuple[str, DocProcessingStatus]], + size_batch: int, + ) -> None: + log_message = ( + f"Start processing batch {batch_idx + 1} of {size_batch}." + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + # 4. iterate over batch + for doc_id_processing_status in docs_batch: + doc_id, status_doc = doc_id_processing_status + # Generate chunks from document + chunks: dict[str, Any] = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_id, + } + for dp in self.chunking_func( + status_doc.content, + split_by_character, + split_by_character_only, + self.chunk_overlap_token_size, + self.chunk_token_size, + self.tiktoken_model_name, + ) + } + # Process document (text chunks and full docs) in parallel + tasks = [ + self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.PROCESSING, + "updated_at": datetime.now().isoformat(), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + } + } + ), + self.chunks_vdb.upsert(chunks), + self._process_entity_relation_graph(chunks), + self.full_docs.upsert( + {doc_id: {"content": status_doc.content}} + ), + self.text_chunks.upsert(chunks), + ] + try: + await asyncio.gather(*tasks) + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.PROCESSED, + "chunks_count": len(chunks), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + } + } + ) + except Exception as e: + logger.error( + f"Failed to process document {doc_id}: {str(e)}" + ) + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.FAILED, + "error": str(e), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + } + } + ) + continue + log_message = ( + f"Completed batch {batch_idx + 1} of {len(docs_batches)}." 
+                        )
+                        logger.info(log_message)
+                        pipeline_status["latest_message"] = log_message
+                        pipeline_status["history_messages"].append(log_message)
+
+                    batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
+
+                await asyncio.gather(*batches)
+                await self._insert_done()
+
+                # Check if there's a pending request to process more documents (with lock)
+                has_pending_request = False
+                async with pipeline_status_lock:
+                    has_pending_request = pipeline_status.get("request_pending", False)
+                    if has_pending_request:
+                        # Clear the request flag before checking for more documents
+                        pipeline_status["request_pending"] = False
+
+                if not has_pending_request:
+                    break
+
+                log_message = "Processing additional documents due to pending request"
+                logger.info(log_message)
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
+
+                # Fetch the next round of documents to process
+                processing_docs, failed_docs, pending_docs = await asyncio.gather(
+                    self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
+                    self.doc_status.get_docs_by_status(DocStatus.FAILED),
+                    self.doc_status.get_docs_by_status(DocStatus.PENDING),
+                )
+
+                to_process_docs = {}
+                to_process_docs.update(processing_docs)
+                to_process_docs.update(failed_docs)
+                to_process_docs.update(pending_docs)
+
+        finally:
+            log_message = "Document processing pipeline completed"
+            logger.info(log_message)
+            # Always reset busy status when done or if an exception occurs (with lock)
+            async with pipeline_status_lock:
+                pipeline_status["busy"] = False
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)

     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
@@ -833,7 +941,16 @@ class LightRAG:
             if storage_inst is not None
         ]
         await asyncio.gather(*tasks)
-        logger.info("All Insert done")
+
+        log_message = "All Insert done"
+        logger.info(log_message)
+
+        # Get pipeline_status and update latest_message and history_messages
+        from lightrag.kg.shared_storage import get_namespace_data
+
+        pipeline_status = await get_namespace_data("pipeline_status")
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)

     def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
         loop = always_get_an_event_loop()
diff --git a/lightrag/operate.py b/lightrag/operate.py
index b52d1ef6..7db42284 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -339,6 +339,9 @@ async def extract_entities(
     global_config: dict[str, str],
     llm_response_cache: BaseKVStorage | None = None,
 ) -> None:
+    from lightrag.kg.shared_storage import get_namespace_data
+
+    pipeline_status = await get_namespace_data("pipeline_status")
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
     enable_llm_cache_for_entity_extract: bool = global_config[
@@ -499,9 +502,10 @@ async def extract_entities(
             processed_chunks += 1
             entities_count = len(maybe_nodes)
             relations_count = len(maybe_edges)
-            logger.info(
-                f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
-            )
+            log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
+            logger.info(log_message)
+            pipeline_status["latest_message"] = log_message
+            pipeline_status["history_messages"].append(log_message)
             return dict(maybe_nodes), dict(maybe_edges)

     tasks = [_process_single_content(c) for c in ordered_chunks]
@@ -530,17 +534,27 @@ async def extract_entities(
     )

     if not (all_entities_data or all_relationships_data):
-        logger.info("Didn't extract any entities and relationships.")
+        log_message = "Didn't extract any entities and relationships."
+        logger.info(log_message)
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
         return

     if not all_entities_data:
-        logger.info("Didn't extract any entities")
+        log_message = "Didn't extract any entities"
+        logger.info(log_message)
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)
     if not all_relationships_data:
-        logger.info("Didn't extract any relationships")
+        log_message = "Didn't extract any relationships"
+        logger.info(log_message)
+        pipeline_status["latest_message"] = log_message
+        pipeline_status["history_messages"].append(log_message)

-    logger.info(
-        f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
-    )
+    log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
+    logger.info(log_message)
+    pipeline_status["latest_message"] = log_message
+    pipeline_status["history_messages"].append(log_message)
     verbose_debug(
         f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
     )
diff --git a/lightrag/utils.py b/lightrag/utils.py
index e7217def..c86ad8c0 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -56,6 +56,18 @@ def set_verbose_debug(enabled: bool):
     VERBOSE_DEBUG = enabled


+statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
+
+# Initialize logger
+logger = logging.getLogger("lightrag")
+logger.propagate = False  # prevent log messages from being sent to the root logger
+# Let the main application configure the handlers
+logger.setLevel(logging.INFO)
+
+# Set httpx logging level to WARNING
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+
 class UnlimitedSemaphore:
     """A context manager that allows unlimited access."""

@@ -68,34 +80,6 @@ class UnlimitedSemaphore:

 ENCODER = None

-statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
-
-logger = logging.getLogger("lightrag")
-
-# Set httpx logging level to WARNING
-logging.getLogger("httpx").setLevel(logging.WARNING)
-
-
-def set_logger(log_file: str, level: int = logging.DEBUG):
-    """Set up file logging with the specified level.
-
-    Args:
-        log_file: Path to the log file
-        level: Logging level (e.g.
logging.DEBUG, logging.INFO) - """ - logger.setLevel(level) - - file_handler = logging.FileHandler(log_file, encoding="utf-8") - file_handler.setLevel(level) - - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - file_handler.setFormatter(formatter) - - if not logger.handlers: - logger.addHandler(file_handler) - @dataclass class EmbeddingFunc: diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py new file mode 100755 index 00000000..2e4e3cf7 --- /dev/null +++ b/run_with_gunicorn.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +""" +Start LightRAG server with Gunicorn +""" + +import os +import sys +import signal +import pipmaster as pm +from lightrag.api.utils_api import parse_args, display_splash_screen +from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data + + +def check_and_install_dependencies(): + """Check and install required dependencies""" + required_packages = [ + "gunicorn", + "tiktoken", + "psutil", + # Add other required packages here + ] + + for package in required_packages: + if not pm.is_installed(package): + print(f"Installing {package}...") + pm.install(package) + print(f"{package} installed successfully") + + +# Signal handler for graceful shutdown +def signal_handler(sig, frame): + print("\n\n" + "=" * 80) + print("RECEIVED TERMINATION SIGNAL") + print(f"Process ID: {os.getpid()}") + print("=" * 80 + "\n") + + # Release shared resources + finalize_share_data() + + # Exit with success status + sys.exit(0) + + +def main(): + # Check and install dependencies + check_and_install_dependencies() + + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # kill command + + # Parse all arguments using parse_args + args = parse_args(is_uvicorn_mode=False) + + # Display startup information + display_splash_screen(args) + + print("๐Ÿš€ Starting LightRAG with Gunicorn") + print(f"๐Ÿ”„ Worker management: Gunicorn (workers={args.workers})") + print("๐Ÿ” Preloading app: Enabled") + print("๐Ÿ“ Note: Using Gunicorn's preload feature for shared data initialization") + print("\n\n" + "=" * 80) + print("MAIN PROCESS INITIALIZATION") + print(f"Process ID: {os.getpid()}") + print(f"Workers setting: {args.workers}") + print("=" * 80 + "\n") + + # Import Gunicorn's StandaloneApplication + from gunicorn.app.base import BaseApplication + + # Define a custom application class that loads our config + class GunicornApp(BaseApplication): + def __init__(self, app, options=None): + self.options = options or {} + self.application = app + super().__init__() + + def load_config(self): + # Define valid Gunicorn configuration options + valid_options = { + "bind", + "workers", + "worker_class", + "timeout", + "keepalive", + "preload_app", + "errorlog", + "accesslog", + "loglevel", + "certfile", + "keyfile", + "limit_request_line", + "limit_request_fields", + "limit_request_field_size", + "graceful_timeout", + "max_requests", + "max_requests_jitter", + } + + # Special hooks that need to be set separately + special_hooks = { + "on_starting", + "on_reload", + "on_exit", + "pre_fork", + "post_fork", + "pre_exec", + "pre_request", + "post_request", + "worker_init", + "worker_exit", + "nworkers_changed", + "child_exit", + } + + # Import and configure the gunicorn_config module + import gunicorn_config + + # Set configuration variables in gunicorn_config, prioritizing command line arguments + gunicorn_config.workers = ( + args.workers if 
args.workers else int(os.getenv("WORKERS", 1)) + ) + + # Bind configuration prioritizes command line arguments + host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") + port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) + gunicorn_config.bind = f"{host}:{port}" + + # Log level configuration prioritizes command line arguments + gunicorn_config.loglevel = ( + args.log_level.lower() + if args.log_level + else os.getenv("LOG_LEVEL", "info") + ) + + # Timeout configuration prioritizes command line arguments + gunicorn_config.timeout = ( + args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + ) + + # Keepalive configuration + gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) + + # SSL configuration prioritizes command line arguments + if args.ssl or os.getenv("SSL", "").lower() in ( + "true", + "1", + "yes", + "t", + "on", + ): + gunicorn_config.certfile = ( + args.ssl_certfile + if args.ssl_certfile + else os.getenv("SSL_CERTFILE") + ) + gunicorn_config.keyfile = ( + args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + ) + + # Set configuration options from the module + for key in dir(gunicorn_config): + if key in valid_options: + value = getattr(gunicorn_config, key) + # Skip functions like on_starting and None values + if not callable(value) and value is not None: + self.cfg.set(key, value) + # Set special hooks + elif key in special_hooks: + value = getattr(gunicorn_config, key) + if callable(value): + self.cfg.set(key, value) + + if hasattr(gunicorn_config, "logconfig_dict"): + self.cfg.set( + "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") + ) + + def load(self): + # Import the application + from lightrag.api.lightrag_server import get_application + + return get_application(args) + + # Create the application + app = GunicornApp("") + + # Force workers to be an integer and greater than 1 for multi-process mode + workers_count = int(args.workers) + if workers_count > 1: + # Set a flag to indicate we're in the main process + os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" + initialize_share_data(workers_count) + else: + initialize_share_data(1) + + # Run the application + print("\nStarting Gunicorn with direct Python API...") + app.run() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index c190bd4d..b9063d7d 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ setuptools.setup( entry_points={ "console_scripts": [ "lightrag-server=lightrag.api.lightrag_server:main [api]", + "lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]", "lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]", ], },