Merge pull request #969 from danielaskdd/add-multi-worker-support

Add multi-worker support for API Server
Merged by Yannick Stephan on 2025-03-02 17:47:00 +01:00 (committed by GitHub)
20 changed files with 1927 additions and 455 deletions

1
.gitignore vendored
View File

@@ -21,6 +21,7 @@ site/
# Logs / Reports
*.log
*.log.*
*.logfire
*.coverage/
log/

View File

@@ -1,6 +1,9 @@
### This is a sample .env file
### Server Configuration
# HOST=0.0.0.0
# PORT=9621
# WORKERS=1
# NAMESPACE_PREFIX=lightrag # separates data from different LightRAG instances
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
@@ -22,6 +25,9 @@
### Logging level
# LOG_LEVEL=INFO
# VERBOSE=False
# LOG_DIR=/path/to/log/directory # Log file directory path, defaults to current working directory
# LOG_MAX_BYTES=10485760 # Log file max size in bytes, defaults to 10MB
# LOG_BACKUP_COUNT=5 # Number of backup files to keep, defaults to 5
### Max async calls for LLM
# MAX_ASYNC=4
@@ -138,3 +144,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
### Qdrant
QDRANT_URL=http://localhost:16333
# QDRANT_API_KEY=your-api-key
### Redis
REDIS_URI=redis://localhost:6379

View File

@@ -24,6 +24,8 @@ pip install -e ".[api]"
### Starting API Server with Default Settings
After installing LightRAG with API support, you can start the LightRAG server with the command: `lightrag-server`
LightRAG requires both an LLM and an embedding model to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends:
* ollama
@@ -92,10 +94,43 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama
LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai
# start with ollama llm and ollama embedding (no API key is needed)
Light_server --llm-binding ollama --embedding-binding ollama
light-server --llm-binding ollama --embedding-binding ollama
```
### Starting API Server with Gunicorn (Production)
For production deployments, it is recommended to run LightRAG under Gunicorn (with Uvicorn workers) to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionality.
```bash
# Start with lightrag-gunicorn command
lightrag-gunicorn --workers 4
# Alternatively, you can use the module directly
python -m lightrag.api.run_with_gunicorn --workers 4
```
The `--workers` parameter is crucial for performance:
- Determines how many worker processes Gunicorn will spawn to handle requests
- Each worker can handle concurrent requests using asyncio
- Recommended value is (2 x number_of_cores) + 1
- For example, on a 4-core machine use 9 workers: (2 x 4) + 1 = 9 (see the sketch after this list)
- Consider your server's memory when setting this value, as each worker consumes memory
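As a rough guide, the heuristic can be computed directly; a minimal sketch (the `WORKERS` name matches the `.env` sample above):

```python
# Sketch: derive a worker count from the (2 x cores) + 1 heuristic.
import os

cpu_cores = os.cpu_count() or 1          # fall back to 1 if undetectable
suggested_workers = (2 * cpu_cores) + 1  # e.g. 4 cores -> 9 workers

# Export it the way the .env sample expects, e.g. WORKERS=9
print(f"WORKERS={suggested_workers}")
```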
Other important startup parameters:
- `--host`: Server listening address (default: 0.0.0.0)
- `--port`: Server listening port (default: 9621)
- `--timeout`: Request handling timeout (default: 150 seconds)
- `--log-level`: Logging level (default: INFO)
- `--ssl`: Enable HTTPS
- `--ssl-certfile`: Path to SSL certificate file
- `--ssl-keyfile`: Path to SSL private key file
The command line parameters and environment variables accepted by `run_with_gunicorn.py` are exactly the same as those of `lightrag-server`.
### For Azure OpenAI Backend
An Azure OpenAI resource can be created with the following Azure CLI commands (install the Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
```bash
# Change the resource group name, location and OpenAI resource name as needed
@@ -186,7 +221,7 @@ LightRAG supports binding to various LLM/Embedding backends:
* openai & openai compatible
* azure_openai
Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type.
Use the environment variable `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use the environment variable `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the Embedding backend type.
### Storage Types Supported

View File

@@ -0,0 +1,187 @@
# gunicorn_config.py
import os
import logging
import logging.handlers  # RotatingFileHandler is referenced directly in post_fork
from lightrag.kg.shared_storage import finalize_share_data
from lightrag.api.lightrag_server import LightragPathFilter
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
# These variables will be set by run_with_gunicorn.py
workers = None
bind = None
loglevel = None
certfile = None
keyfile = None
# Enable preload_app option
preload_app = True
# Use Uvicorn worker
worker_class = "uvicorn.workers.UvicornWorker"
# Other Gunicorn configurations
timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py
keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s
# Logging configuration
errorlog = os.getenv("ERROR_LOG", log_file_path) # Default write to lightrag.log
accesslog = os.getenv("ACCESS_LOG", log_file_path) # Default write to lightrag.log
logconfig_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "standard",
"stream": "ext://sys.stdout",
},
"file": {
"class": "logging.handlers.RotatingFileHandler",
"formatter": "standard",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf8",
},
},
"filters": {
"path_filter": {
"()": "lightrag.api.lightrag_server.LightragPathFilter",
},
},
"loggers": {
"lightrag": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
},
"gunicorn": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
},
"gunicorn.error": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
},
"gunicorn.access": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
"filters": ["path_filter"],
},
},
}
def on_starting(server):
"""
Executed when Gunicorn starts, before forking the first worker processes
You can use this function to do more initialization tasks for all processes
"""
print("=" * 80)
print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)")
print(f"Process ID: {os.getpid()}")
print("=" * 80)
# Memory usage monitoring
try:
import psutil
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
msg = (
f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB"
)
print(msg)
except ImportError:
print("psutil not installed, skipping memory usage reporting")
print("Gunicorn initialization complete, forking workers...\n")
def on_exit(server):
"""
Executed when Gunicorn is shutting down.
This is a good place to release shared resources.
"""
print("=" * 80)
print("GUNICORN MASTER PROCESS: Shutting down")
print(f"Process ID: {os.getpid()}")
print("=" * 80)
# Release shared resources
finalize_share_data()
print("=" * 80)
print("Gunicorn shutdown complete")
print("=" * 80)
def post_fork(server, worker):
"""
Executed after a worker has been forked.
This is a good place to set up worker-specific configurations.
"""
# Configure formatters
detailed_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False):
"""Set up a logger with console and file handlers"""
logger_instance = logging.getLogger(logger_name)
logger_instance.setLevel(level)
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
# Add path filter if requested
if add_filter:
path_filter = LightragPathFilter()
logger_instance.addFilter(path_filter)
# Set up main loggers
log_level = loglevel.upper() if loglevel else "INFO"
setup_logger("uvicorn", log_level)
setup_logger("uvicorn.access", log_level, add_filter=True)
setup_logger("lightrag", log_level, add_filter=True)
# Set up lightrag submodule loggers
for name in logging.root.manager.loggerDict:
if name.startswith("lightrag."):
setup_logger(name, log_level, add_filter=True)
# Disable uvicorn.error logger
uvicorn_error_logger = logging.getLogger("uvicorn.error")
uvicorn_error_logger.handlers = []
uvicorn_error_logger.setLevel(logging.CRITICAL)
uvicorn_error_logger.propagate = False
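The rotation settings above (`LOG_MAX_BYTES`, `LOG_BACKUP_COUNT`) can be exercised on their own; the following is a standalone sketch, not part of the diff, using the same defaults:

```python
# Standalone sketch: the rotation settings from gunicorn_config.py applied to a throwaway logger.
import logging
import logging.handlers
import os

max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # default 10MB
backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))   # default 5 backups

handler = logging.handlers.RotatingFileHandler(
    filename="lightrag.log",
    maxBytes=max_bytes,
    backupCount=backup_count,
    encoding="utf-8",
)
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))

demo_logger = logging.getLogger("rotation-demo")
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)
demo_logger.info("rotation demo entry")  # rolls over to lightrag.log.1 ... .5 once max_bytes is exceeded
```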

View File

@@ -8,11 +8,12 @@ from fastapi import (
)
from fastapi.responses import FileResponse
import asyncio
import threading
import os
from fastapi.staticfiles import StaticFiles
import logging
from typing import Dict
import logging.config
import uvicorn
import pipmaster as pm
from fastapi.staticfiles import StaticFiles
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
@@ -29,7 +30,6 @@ from lightrag import LightRAG
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.utils import logger
from .routers.document_routes import (
DocumentManager,
create_document_routes,
@@ -39,33 +39,25 @@ from .routers.query_routes import create_query_routes
from .routers.graph_routes import create_graph_routes
from .routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
initialize_pipeline_status,
get_all_update_flags_status,
)
# Load environment variables
try:
load_dotenv(override=True)
except Exception as e:
logger.warning(f"Failed to load .env file: {e}")
load_dotenv(override=True)
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
# Global configuration
global_top_k = 60 # default value
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
class LightragPathFilter(logging.Filter):
"""Filter for lightrag logger to filter out frequent path access logs"""
# Lock for thread-safe operations
progress_lock = threading.Lock()
class AccessLogFilter(logging.Filter):
def __init__(self):
super().__init__()
# Define paths to be filtered
@@ -73,17 +65,18 @@ class AccessLogFilter(logging.Filter):
def filter(self, record):
try:
# Check if record has the required attributes for an access log
if not hasattr(record, "args") or not isinstance(record.args, tuple):
return True
if len(record.args) < 5:
return True
# Extract method, path and status from the record args
method = record.args[1]
path = record.args[2]
status = record.args[4]
# print(f"Debug - Method: {method}, Path: {path}, Status: {status}")
# print(f"Debug - Filtered paths: {self.filtered_paths}")
# Filter out successful GET requests to filtered paths
if (
method == "GET"
and (status == 200 or status == 304)
@@ -92,19 +85,14 @@ class AccessLogFilter(logging.Filter):
return False
return True
except Exception:
# In case of any error, let the message through
return True
def create_app(args):
# Set global top_k
global global_top_k
global_top_k = args.top_k # save top_k from args
# Initialize verbose debug setting
from lightrag.utils import set_verbose_debug
# Setup logging
logger.setLevel(args.log_level)
set_verbose_debug(args.verbose)
# Verify that bindings are correctly setup
@@ -138,11 +126,6 @@ def create_app(args):
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
@@ -158,28 +141,23 @@ def create_app(args):
try:
# Initialize database connections
await rag.initialize_storages()
await initialize_pipeline_status()
# Auto scan documents if enabled
if args.auto_scan_at_startup:
# Start scanning in background
with progress_lock:
if not scan_progress["is_scanning"]:
scan_progress["is_scanning"] = True
scan_progress["indexed_count"] = 0
scan_progress["progress"] = 0
# Create background task
task = asyncio.create_task(
run_scanning_process(rag, doc_manager)
)
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
ASCIIColors.info(
f"Started background scanning of documents from {args.input_dir}"
)
else:
ASCIIColors.info(
"Skip document scanning(another scanning is active)"
)
# Check if a task is already running (with lock protection)
pipeline_status = await get_namespace_data("pipeline_status")
should_start_task = False
async with get_pipeline_status_lock():
if not pipeline_status.get("busy", False):
should_start_task = True
# Only start the task if no other task is running
if should_start_task:
# Create background task
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard)
logger.info("Auto scan task started at startup.")
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@@ -398,6 +376,9 @@ def create_app(args):
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
# Get update flags status for all namespaces
update_status = await get_all_update_flags_status()
return {
"status": "healthy",
"working_directory": str(args.working_dir),
@@ -417,6 +398,7 @@ def create_app(args):
"graph_storage": args.graph_storage,
"vector_storage": args.vector_storage,
},
"update_status": update_status,
}
# Webui mount webui/index.html
@@ -435,12 +417,30 @@ def create_app(args):
return app
def main():
args = parse_args()
import uvicorn
import logging.config
def get_application(args=None):
"""Factory function for creating the FastAPI application"""
if args is None:
args = parse_args()
return create_app(args)
def configure_logging():
"""Configure logging for uvicorn startup"""
# Reset any existing handlers to ensure clean configuration
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
logger = logging.getLogger(logger_name)
logger.handlers = []
logger.filters = []
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
# Configure uvicorn logging
logging.config.dictConfig(
{
"version": 1,
@@ -449,36 +449,106 @@ def main():
"default": {
"format": "%(levelname)s: %(message)s",
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"default": {
"console": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
},
"loggers": {
"uvicorn.access": {
"handlers": ["default"],
# Configure all uvicorn related loggers
"uvicorn": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"uvicorn.access": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
"uvicorn.error": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
},
"filters": {
"path_filter": {
"()": "lightrag.api.lightrag_server.LightragPathFilter",
},
},
}
)
# Add filter to uvicorn access logger
uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_access_logger.addFilter(AccessLogFilter())
app = create_app(args)
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"uvicorn",
"tiktoken",
"fastapi",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
def main():
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application
print("Running under Gunicorn - worker management handled by Gunicorn")
return
# Check and install dependencies
check_and_install_dependencies()
from multiprocessing import freeze_support
freeze_support()
# Configure logging before parsing args
configure_logging()
args = parse_args(is_uvicorn_mode=True)
display_splash_screen(args)
# Create application instance directly instead of using factory function
app = create_app(args)
# Start Uvicorn in single process mode
uvicorn_config = {
"app": app,
"app": app, # Pass application instance directly instead of string path
"host": args.host,
"port": args.port,
"log_config": None, # Disable default config
}
if args.ssl:
uvicorn_config.update(
{
@@ -486,6 +556,8 @@ def main():
"ssl_keyfile": args.ssl_keyfile,
}
)
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
uvicorn.run(**uvicorn_config)
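Since `get_application` is a zero-argument factory, the same app can also be launched through Uvicorn's factory mode; a hedged, single-process sketch (not part of the diff):

```python
# Sketch: launch the LightRAG app via Uvicorn's factory mode (single process only).
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "lightrag.api.lightrag_server:get_application",
        factory=True,     # Uvicorn calls get_application() to build the app
        host="0.0.0.0",
        port=9621,
        log_config=None,  # keep LightRAG's own logging configuration
    )
```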

View File

@@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API.
"""
import asyncio
import logging
import os
from lightrag.utils import logger
import aiofiles
import shutil
import traceback
@@ -12,7 +11,6 @@ import pipmaster as pm
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel, Field, field_validator
@@ -23,18 +21,6 @@ from ..utils_api import get_api_key_dependency
router = APIRouter(prefix="/documents", tags=["documents"])
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = asyncio.Lock()
# Temporary file prefix
temp_prefix = "__tmp__"
@@ -161,19 +147,12 @@ class DocumentManager:
"""Scan input directory for new files"""
new_files = []
for ext in self.supported_extensions:
logging.debug(f"Scanning for {ext} files in {self.input_dir}")
logger.debug(f"Scanning for {ext} files in {self.input_dir}")
for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files:
new_files.append(file_path)
return new_files
# def scan_directory(self) -> List[Path]:
# new_files = []
# for ext in self.supported_extensions:
# for file_path in self.input_dir.rglob(f"*{ext}"):
# new_files.append(file_path)
# return new_files
def mark_as_indexed(self, file_path: Path):
self.indexed_files.add(file_path)
@@ -287,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
)
content += "\n"
case _:
logging.error(
logger.error(
f"Unsupported file type: {file_path.name} (extension {ext})"
)
return False
@@ -295,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
# Insert into the RAG queue
if content:
await rag.apipeline_enqueue_documents(content)
logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
logger.info(f"Successfully fetched and enqueued file: {file_path.name}")
return True
else:
logging.error(f"No content could be extracted from file: {file_path.name}")
logger.error(f"No content could be extracted from file: {file_path.name}")
except Exception as e:
logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
logger.error(traceback.format_exc())
finally:
if file_path.name.startswith(temp_prefix):
try:
file_path.unlink()
except Exception as e:
logging.error(f"Error deleting file {file_path}: {str(e)}")
logger.error(f"Error deleting file {file_path}: {str(e)}")
return False
@@ -324,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error indexing file {file_path.name}: {str(e)}")
logger.error(traceback.format_exc())
async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
@@ -349,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
if enqueued:
await rag.apipeline_process_enqueue_documents()
except Exception as e:
logging.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error indexing files: {str(e)}")
logger.error(traceback.format_exc())
async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
@@ -393,30 +372,17 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
"""Background task to scan and index documents"""
try:
new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files)
total_files = len(new_files)
logger.info(f"Found {total_files} new files to index.")
logging.info(f"Found {len(new_files)} new files to index.")
for file_path in new_files:
for idx, file_path in enumerate(new_files):
try:
async with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await pipeline_index_file(rag, file_path)
async with progress_lock:
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"] / scan_progress["total_files"]
) * 100
except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}")
logger.error(f"Error indexing file {file_path}: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
async with progress_lock:
scan_progress["is_scanning"] = False
logger.error(f"Error during scanning process: {str(e)}")
def create_document_routes(
@@ -436,34 +402,10 @@ def create_document_routes(
Returns:
dict: A dictionary containing the scanning status
"""
async with progress_lock:
if scan_progress["is_scanning"]:
return {"status": "already_scanning"}
scan_progress["is_scanning"] = True
scan_progress["indexed_count"] = 0
scan_progress["progress"] = 0
# Start the scanning process in the background
background_tasks.add_task(run_scanning_process, rag, doc_manager)
return {"status": "scanning_started"}
@router.get("/scan-progress")
async def get_scan_progress():
"""
Get the current progress of the document scanning process.
Returns:
dict: A dictionary containing the current scanning progress information including:
- is_scanning: Whether a scan is currently in progress
- current_file: The file currently being processed
- indexed_count: Number of files indexed so far
- total_files: Total number of files to process
- progress: Percentage of completion
"""
async with progress_lock:
return scan_progress
@router.post("/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...)
@@ -504,8 +446,8 @@ def create_document_routes(
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post(
@@ -537,8 +479,8 @@ def create_document_routes(
message="Text successfully received. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error /documents/text: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post(
@@ -572,8 +514,8 @@ def create_document_routes(
message="Text successfully received. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error /documents/text: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post(
@@ -615,8 +557,8 @@ def create_document_routes(
message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
)
except Exception as e:
logging.error(f"Error /documents/file: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error /documents/file: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post(
@@ -678,8 +620,8 @@ def create_document_routes(
return InsertResponse(status=status, message=status_message)
except Exception as e:
logging.error(f"Error /documents/batch: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error /documents/batch: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.delete(
@@ -706,8 +648,42 @@ def create_document_routes(
status="success", message="All documents cleared successfully"
)
except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error DELETE /documents: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("/pipeline_status", dependencies=[Depends(optional_api_key)])
async def get_pipeline_status():
"""
Get the current status of the document indexing pipeline.
This endpoint returns information about the current state of the document processing pipeline,
including whether it's busy, the current job name, when it started, how many documents
are being processed, how many batches there are, and which batch is currently being processed.
Returns:
dict: A dictionary containing the pipeline status information
"""
try:
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
# Convert to regular dict if it's a Manager.dict
status_dict = dict(pipeline_status)
# Convert history_messages to a regular list if it's a Manager.list
if "history_messages" in status_dict:
status_dict["history_messages"] = list(status_dict["history_messages"])
# Format the job_start time if it exists
if status_dict.get("job_start"):
status_dict["job_start"] = str(status_dict["job_start"])
return status_dict
except Exception as e:
logger.error(f"Error getting pipeline status: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("", dependencies=[Depends(optional_api_key)])
@@ -763,8 +739,8 @@ def create_document_routes(
)
return response
except Exception as e:
logging.error(f"Error GET /documents: {str(e)}")
logging.error(traceback.format_exc())
logger.error(f"Error GET /documents: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
return router
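For reference, a minimal client sketch for the new `/documents/pipeline_status` endpoint (not part of the diff; assumes the default host/port, the `requests` package, and no API key requirement):

```python
# Sketch: poll the pipeline status endpoint exposed above.
import requests

resp = requests.get("http://localhost:9621/documents/pipeline_status", timeout=10)
resp.raise_for_status()
status = resp.json()
# Field names follow the endpoint's docstring (e.g. "busy", "job_start").
print("busy:", status.get("busy"), "| job started:", status.get("job_start"))
```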

View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python
"""
Start LightRAG server with Gunicorn
"""
import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"gunicorn",
"tiktoken",
"psutil",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
# Signal handler for graceful shutdown
def signal_handler(sig, frame):
print("\n\n" + "=" * 80)
print("RECEIVED TERMINATION SIGNAL")
print(f"Process ID: {os.getpid()}")
print("=" * 80 + "\n")
# Release shared resources
finalize_share_data()
# Exit with success status
sys.exit(0)
def main():
# Check and install dependencies
check_and_install_dependencies()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information
display_splash_screen(args)
print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}")
print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication
from gunicorn.app.base import BaseApplication
# Define a custom application class that loads our config
class GunicornApp(BaseApplication):
def __init__(self, app, options=None):
self.options = options or {}
self.application = app
super().__init__()
def load_config(self):
# Define valid Gunicorn configuration options
valid_options = {
"bind",
"workers",
"worker_class",
"timeout",
"keepalive",
"preload_app",
"errorlog",
"accesslog",
"loglevel",
"certfile",
"keyfile",
"limit_request_line",
"limit_request_fields",
"limit_request_field_size",
"graceful_timeout",
"max_requests",
"max_requests_jitter",
}
# Special hooks that need to be set separately
special_hooks = {
"on_starting",
"on_reload",
"on_exit",
"pre_fork",
"post_fork",
"pre_exec",
"pre_request",
"post_request",
"worker_init",
"worker_exit",
"nworkers_changed",
"child_exit",
}
# Import and configure the gunicorn_config module
from lightrag.api import gunicorn_config
# Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1))
)
# Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = (
args.log_level.lower()
if args.log_level
else os.getenv("LOG_LEVEL", "info")
)
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
)
# Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in (
"true",
"1",
"yes",
"t",
"on",
):
gunicorn_config.certfile = (
args.ssl_certfile
if args.ssl_certfile
else os.getenv("SSL_CERTFILE")
)
gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
)
# Set configuration options from the module
for key in dir(gunicorn_config):
if key in valid_options:
value = getattr(gunicorn_config, key)
# Skip functions like on_starting and None values
if not callable(value) and value is not None:
self.cfg.set(key, value)
# Set special hooks
elif key in special_hooks:
value = getattr(gunicorn_config, key)
if callable(value):
self.cfg.set(key, value)
if hasattr(gunicorn_config, "logconfig_dict"):
self.cfg.set(
"logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
)
def load(self):
# Import the application
from lightrag.api.lightrag_server import get_application
return get_application(args)
# Create the application
app = GunicornApp("")
# Ensure workers is an integer; initialize shared data for multi-process mode when workers > 1
workers_count = int(args.workers)
if workers_count > 1:
# Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
initialize_share_data(workers_count)
else:
initialize_share_data(1)
# Run the application
print("\nStarting Gunicorn with direct Python API...")
app.run()
if __name__ == "__main__":
main()
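To illustrate why `initialize_share_data()` runs in the master process before Gunicorn forks (the preload pattern noted in the startup messages), here is a toy sketch, not LightRAG code, assuming a Unix fork start method:

```python
# Toy sketch: objects created in the parent before forking are inherited by workers.
import os
from multiprocessing import get_context

shared = {"initialized_by": os.getpid()}  # created once, pre-fork

def worker() -> None:
    # Each forked child sees the value written by the parent process.
    print(f"worker {os.getpid()} sees data initialized by pid {shared['initialized_by']}")

if __name__ == "__main__":
    ctx = get_context("fork")  # fork-based, like Gunicorn's worker model (Unix only)
    processes = [ctx.Process(target=worker) for _ in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```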

View File

@@ -6,6 +6,7 @@ import os
import argparse
from typing import Optional
import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
@@ -110,10 +111,13 @@ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
return default
def parse_args() -> argparse.Namespace:
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
"""
Parse command line arguments with environment variable fallback
Args:
is_uvicorn_mode: Whether running under uvicorn mode
Returns:
argparse.Namespace: Parsed arguments
"""
@@ -260,6 +264,14 @@ def parse_args() -> argparse.Namespace:
help="Enable automatic scanning when the program starts",
)
# Server workers configuration
parser.add_argument(
"--workers",
type=int,
default=get_env_value("WORKERS", 1, int),
help="Number of worker processes (default: from env or 1)",
)
# LLM and embedding bindings
parser.add_argument(
"--llm-binding",
@@ -278,6 +290,15 @@ def parse_args() -> argparse.Namespace:
args = parser.parse_args()
# If in uvicorn mode and workers > 1, force it to 1 and log warning
if is_uvicorn_mode and args.workers > 1:
original_workers = args.workers
args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
# convert relative path to absolute path
args.working_dir = os.path.abspath(args.working_dir)
args.input_dir = os.path.abspath(args.input_dir)
@@ -346,17 +367,27 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.host}")
ASCIIColors.white(" ├─ Port: ", end="")
ASCIIColors.yellow(f"{args.port}")
ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}")
ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
ASCIIColors.yellow(f"{args.ssl}")
ASCIIColors.white(" └─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set")
if args.ssl:
ASCIIColors.white(" ├─ SSL Cert: ", end="")
ASCIIColors.yellow(f"{args.ssl_certfile}")
ASCIIColors.white(" ─ SSL Key: ", end="")
ASCIIColors.white(" ─ SSL Key: ", end="")
ASCIIColors.yellow(f"{args.ssl_keyfile}")
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
ASCIIColors.white(" ├─ Log Level: ", end="")
ASCIIColors.yellow(f"{args.log_level}")
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" ├─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
ASCIIColors.white(" └─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set")
# Directory Configuration
ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -415,16 +446,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.white(" └─ Document Status Storage: ", end="")
ASCIIColors.yellow(f"{args.doc_status_storage}")
ASCIIColors.magenta("\n🛠️ System Configuration:")
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
ASCIIColors.white(" ├─ Log Level: ", end="")
ASCIIColors.yellow(f"{args.log_level}")
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" └─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
# Server Status
ASCIIColors.green("\n✨ Server starting up...\n")
@@ -478,7 +499,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.cyan(""" 3. Basic Operations:
- POST /upload_document: Upload new documents to RAG
- POST /query: Query your document collection
- GET /collections: List available collections
4. Monitor the server:
- Check server logs for detailed operation information

View File

@@ -2,25 +2,25 @@ import os
import time
import asyncio
from typing import Any, final
import json
import numpy as np
from dataclasses import dataclass
import pipmaster as pm
from lightrag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.base import (
BaseVectorStorage,
)
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage
if not pm.is_installed("faiss"):
pm.install("faiss")
import faiss
import faiss # type: ignore
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final
@@ -55,14 +55,40 @@ class FaissVectorDBStorage(BaseVectorStorage):
# If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP.
self._index = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta = {}
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_index(self):
"""Check if the shtorage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if storage was updated by another process
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
):
logger.info(
f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
)
# Reload data
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._index
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Insert or update vectors in the Faiss index.
@@ -113,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
return []
# Normalize embeddings for cosine similarity (in-place)
# Convert to float32 and normalize embeddings for cosine similarity (in-place)
embeddings = embeddings.astype(np.float32)
faiss.normalize_L2(embeddings)
# Upsert logic:
@@ -127,18 +154,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
existing_ids_to_remove.append(faiss_internal_id)
if existing_ids_to_remove:
self._remove_faiss_ids(existing_ids_to_remove)
await self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors
start_idx = self._index.ntotal
self._index.add(embeddings)
index = await self._get_index()
start_idx = index.ntotal
index.add(embeddings)
# Step 3: Store metadata + vector for each new ID
for i, meta in enumerate(list_data):
fid = start_idx + i
# Store the raw vector so we can rebuild if something is removed
meta["__vector__"] = embeddings[i].tolist()
self._id_to_meta[fid] = meta
self._id_to_meta.update({fid: meta})
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data]
@@ -157,7 +185,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
# Perform the similarity search
distances, indices = self._index.search(embedding, top_k)
index = await self._get_index()
distances, indices = index.search(embedding, top_k)
distances = distances[0]
indices = indices[0]
@@ -201,8 +230,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
to_remove.append(fid)
if to_remove:
self._remove_faiss_ids(to_remove)
logger.info(
await self._remove_faiss_ids(to_remove)
logger.debug(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
@@ -223,12 +252,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Found {len(relations)} relations for {entity_name}")
if relations:
self._remove_faiss_ids(relations)
await self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
async def index_done_callback(self) -> None:
self._save_faiss_index()
# --------------------------------------------------------------------------------
# Internal helper methods
# --------------------------------------------------------------------------------
@@ -242,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
return fid
return None
def _remove_faiss_ids(self, fid_list):
async def _remove_faiss_ids(self, fid_list):
"""
Remove a list of internal Faiss IDs from the index.
Because IndexFlatIP doesn't support 'removals',
@@ -258,13 +284,14 @@ class FaissVectorDBStorage(BaseVectorStorage):
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
new_id_to_meta[new_fid] = vec_meta
# Re-init index
self._index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep:
arr = np.array(vectors_to_keep, dtype=np.float32)
self._index.add(arr)
async with self._storage_lock:
# Re-init index
self._index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep:
arr = np.array(vectors_to_keep, dtype=np.float32)
self._index.add(arr)
self._id_to_meta = new_id_to_meta
self._id_to_meta = new_id_to_meta
def _save_faiss_index(self):
"""
@@ -312,3 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.warning("Starting with an empty Faiss index.")
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
async def index_done_callback(self) -> None:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
)
async with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
self._save_faiss_index()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
except Exception as e:
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
return False # Return error
return True # Return success

View File

@@ -12,6 +12,11 @@ from lightrag.utils import (
logger,
write_json,
)
from .shared_storage import (
get_namespace_data,
get_storage_lock,
try_initialize_namespace,
)
@final
@@ -22,26 +27,42 @@ class JsonDocStatusStorage(DocStatusStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data: dict[str, Any] = load_json(self._file_name) or {}
logger.info(f"Loaded document status storage with {len(self._data)} records")
self._storage_lock = get_storage_lock()
self._data = None
async def initialize(self):
"""Initialize storage data"""
# need_init must be checked before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(
f"Loaded document status storage with {len(loaded_data)} records"
)
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
return set(keys) - set(self._data.keys())
async with self._storage_lock:
return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
for id in ids:
data = self._data.get(id, None)
if data:
result.append(data)
async with self._storage_lock:
for id in ids:
data = self._data.get(id, None)
if data:
result.append(data)
return result
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus}
for doc in self._data.values():
counts[doc["status"]] += 1
async with self._storage_lock:
for doc in self._data.values():
counts[doc["status"]] += 1
return counts
async def get_docs_by_status(
@@ -49,39 +70,48 @@ class JsonDocStatusStorage(DocStatusStorage):
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
result = {}
for k, v in self._data.items():
if v["status"] == status.value:
try:
# Make a copy of the data to avoid modifying the original
data = v.copy()
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")
continue
async with self._storage_lock:
for k, v in self._data.items():
if v["status"] == status.value:
try:
# Make a copy of the data to avoid modifying the original
data = v.copy()
# If content is missing, use content_summary as content
if "content" not in data and "content_summary" in data:
data["content"] = data["content_summary"]
result[k] = DocProcessingStatus(**data)
except KeyError as e:
logger.error(f"Missing required field for document {k}: {e}")
continue
return result
async def index_done_callback(self) -> None:
write_json(self._data, self._file_name)
async with self._storage_lock:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
write_json(data_dict, self._file_name)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
self._data.update(data)
async with self._storage_lock:
self._data.update(data)
await self.index_done_callback()
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.get(id)
async with self._storage_lock:
return self._data.get(id)
async def delete(self, doc_ids: list[str]):
for doc_id in doc_ids:
self._data.pop(doc_id, None)
async with self._storage_lock:
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await self.index_done_callback()
async def drop(self) -> None:
"""Drop the storage"""
self._data.clear()
async with self._storage_lock:
self._data.clear()

View File

@@ -1,4 +1,3 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, final
@@ -11,6 +10,11 @@ from lightrag.utils import (
logger,
write_json,
)
from .shared_storage import (
get_namespace_data,
get_storage_lock,
try_initialize_namespace,
)
@final
@@ -19,37 +23,56 @@ class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data: dict[str, Any] = load_json(self._file_name) or {}
self._lock = asyncio.Lock()
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
self._storage_lock = get_storage_lock()
self._data = None
async def initialize(self):
"""Initialize storage data"""
# need_init must be checked before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
async def index_done_callback(self) -> None:
write_json(self._data, self._file_name)
async with self._storage_lock:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
write_json(data_dict, self._file_name)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return self._data.get(id)
async with self._storage_lock:
return self._data.get(id)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return [
(
{k: v for k, v in self._data[id].items()}
if self._data.get(id, None)
else None
)
for id in ids
]
async with self._storage_lock:
return [
(
{k: v for k, v in self._data[id].items()}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, keys: set[str]) -> set[str]:
return set(keys) - set(self._data.keys())
async with self._storage_lock:
return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
async with self._storage_lock:
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
async def delete(self, ids: list[str]) -> None:
for doc_id in ids:
self._data.pop(doc_id, None)
async with self._storage_lock:
for doc_id in ids:
self._data.pop(doc_id, None)
await self.index_done_callback()

View File

@@ -3,7 +3,6 @@ import os
from typing import Any, final
from dataclasses import dataclass
import numpy as np
import time
from lightrag.utils import (
@@ -11,22 +10,29 @@ from lightrag.utils import (
compute_mdhash_id,
)
import pipmaster as pm
from lightrag.base import (
BaseVectorStorage,
)
from lightrag.base import BaseVectorStorage
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self):
# Initialize lock only for file operations
self._save_lock = asyncio.Lock()
# Initialize basic attributes
self._client = None
self._storage_lock = None
self.storage_updated = None
# Use global config value if specified, otherwise use default
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -40,10 +46,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_client(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
):
logger.info(
f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
)
# Reload data
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
@@ -64,6 +103,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
for i in range(0, len(contents), self._max_batch_size)
]
# Execute embedding outside of lock to avoid long lock times
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = await asyncio.gather(*embedding_tasks)
@@ -71,7 +111,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
client = await self._get_client()
results = client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
@@ -80,9 +121,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid long lock times
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._client.query(
client = await self._get_client()
results = client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
@@ -99,8 +143,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
return results
@property
def client_storage(self):
return getattr(self._client, "_NanoVectorDB__storage")
async def client_storage(self):
client = await self._get_client()
return getattr(client, "_NanoVectorDB__storage")
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
@@ -109,8 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted
"""
try:
self._client.delete(ids)
logger.info(
client = await self._get_client()
client.delete(ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
@@ -122,9 +168,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Check if the entity exists
if self._client.get([entity_id]):
await self.delete([entity_id])
client = await self._get_client()
if client.get([entity_id]):
client.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
@@ -133,16 +181,19 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None:
try:
client = await self._get_client()
storage = getattr(client, "_NanoVectorDB__storage")
relations = [
dp
for dp in self.client_storage["data"]
for dp in storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
await self.delete(ids_to_delete)
client = await self._get_client()
client.delete(ids_to_delete)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
@@ -151,6 +202,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self) -> None:
async with self._save_lock:
self._client.save()
async def index_done_callback(self) -> bool:
"""Save data to disk"""
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for {self.namespace} was updated by another process, reloading..."
)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
self._client.save()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return True # Return success
except Exception as e:
logger.error(f"Error saving data for {self.namespace}: {e}")
return False # Return error
return True # Return success

View File

@@ -1,18 +1,12 @@
import os
from dataclasses import dataclass
from typing import Any, final
import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import (
logger,
)
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from lightrag.base import (
BaseGraphStorage,
)
import pipmaster as pm
if not pm.is_installed("networkx"):
@@ -23,6 +17,12 @@ if not pm.is_installed("graspologic"):
import networkx as nx
from graspologic import embed
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final
@@ -78,56 +78,101 @@ class NetworkXStorage(BaseGraphStorage):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
self._storage_lock = None
self.storage_updated = None
self._graph = None
# Load initial graph
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self) -> None:
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_graph(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
):
logger.info(
f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
)
# Reload data
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._graph
async def has_node(self, node_id: str) -> bool:
return self._graph.has_node(node_id)
graph = await self._get_graph()
return graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
graph = await self._get_graph()
return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id)
graph = await self._get_graph()
return graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
return self._graph.degree(node_id)
graph = await self._get_graph()
return graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
graph = await self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id))
graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
graph = await self._get_graph()
if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data)
graph = await self._get_graph()
graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
graph = await self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None:
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
graph = await self._get_graph()
if graph.has_node(node_id):
graph.remove_node(node_id)
logger.debug(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
@@ -138,35 +183,37 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
# TODO: NOT USED
async def _node2vec_embed(self):
graph = await self._get_graph()
embeddings, nodes = embed.node2vec_embed(
self._graph,
graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
def remove_nodes(self, nodes: list[str]):
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to be deleted
"""
graph = await self._get_graph()
for node in nodes:
if self._graph.has_node(node):
self._graph.remove_node(node)
if graph.has_node(node):
graph.remove_node(node)
def remove_edges(self, edges: list[tuple[str, str]]):
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
graph = await self._get_graph()
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)
if graph.has_edge(source, target):
graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]:
"""
@@ -174,8 +221,9 @@ class NetworkXStorage(BaseGraphStorage):
Returns:
[label1, label2, ...] # Alphabetically sorted label list
"""
graph = await self._get_graph()
labels = set()
for node in self._graph.nodes():
for node in graph.nodes():
labels.add(str(node)) # Add node id as a label
# Return sorted list
@@ -198,16 +246,18 @@ class NetworkXStorage(BaseGraphStorage):
seen_nodes = set()
seen_edges = set()
graph = await self._get_graph()
# Handle special case for "*" label
if node_label == "*":
# For "*", return the entire graph including all nodes and edges
subgraph = (
self._graph.copy()
graph.copy()
) # Create a copy to avoid modifying the original graph
else:
# Find nodes with matching node id (partial match)
nodes_to_explore = []
for n, attr in self._graph.nodes(data=True):
for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching
nodes_to_explore.append(n)
@@ -216,7 +266,7 @@ class NetworkXStorage(BaseGraphStorage):
return result
# Get subgraph using ego_graph
subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500
@@ -278,9 +328,41 @@ class NetworkXStorage(BaseGraphStorage):
)
seen_edges.add(edge_id)
# logger.info(result.edges)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def index_done_callback(self) -> bool:
"""Save data to disk"""
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Graph for {self.namespace} was updated by another process, reloading..."
)
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
self.storage_updated.value = False
return False  # Return False: graph was reloaded instead of saved
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return True # Return success
except Exception as e:
logger.error(f"Error saving graph for {self.namespace}: {e}")
return False # Return error
return True

View File

@@ -38,8 +38,8 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import asyncpg
from asyncpg import Pool
import asyncpg # type: ignore
from asyncpg import Pool # type: ignore
class PostgreSQLDB:

View File

@@ -0,0 +1,374 @@
import os
import sys
import asyncio
from multiprocessing.synchronize import Lock as ProcessLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union, TypeVar, Generic
# Define a direct print function for critical logs that must be visible in all processes
def direct_log(message, level="INFO"):
"""
Log a message directly to stderr to ensure visibility in all processes,
including the Gunicorn master process.
"""
print(f"{level}: {message}", file=sys.stderr, flush=True)
T = TypeVar("T")
LockType = Union[ProcessLock, asyncio.Lock]
is_multiprocess = None
_workers = None
_manager = None
_initialized = None
# shared data for storage across processes
_shared_dicts: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
# locks for mutex access
_storage_lock: Optional[LockType] = None
_internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None
class UnifiedLock(Generic[T]):
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
self._lock = lock
self._is_async = is_async
async def __aenter__(self) -> "UnifiedLock[T]":
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._is_async:
self._lock.release()
else:
self._lock.release()
def __enter__(self) -> "UnifiedLock[T]":
"""For backward compatibility"""
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
self._lock.acquire()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""For backward compatibility"""
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
self._lock.release()
def get_internal_lock() -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
def get_storage_lock() -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
def get_pipeline_status_lock() -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
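Because every getter wraps its lock in UnifiedLock, callers use the same `async with` syntax whether the underlying lock is an asyncio.Lock (single process) or a Manager lock (multi-process). A minimal usage sketch; the critical-section body is illustrative:

```python
from lightrag.kg.shared_storage import get_storage_lock

async def save_under_lock(save_fn):
    # UnifiedLock hides whether this is an asyncio.Lock or a multiprocessing
    # Manager lock; the same async context manager works in both modes.
    async with get_storage_lock():
        save_fn()
```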
def initialize_share_data(workers: int = 1):
"""
Initialize shared storage data for single or multi-process mode.
When used with Gunicorn's preload feature, this function is called once in the
master process before forking worker processes, allowing all workers to share
the same initialized data.
In single-process mode, this function is called in the FastAPI lifespan function.
The function determines whether to use cross-process shared variables for data storage
based on the number of workers. If workers=1, it uses asyncio locks and local dictionaries.
If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager.
Args:
workers (int): Number of worker processes. If 1, single-process mode is used.
If > 1, multi-process mode with shared memory is used.
"""
global \
_manager, \
_workers, \
is_multiprocess, \
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
_shared_dicts, \
_init_flags, \
_initialized, \
_update_flags
# Check if already initialized
if _initialized:
direct_log(
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
)
return
_manager = Manager()
_workers = workers
if workers > 1:
is_multiprocess = True
_internal_lock = _manager.Lock()
_storage_lock = _manager.Lock()
_pipeline_status_lock = _manager.Lock()
_shared_dicts = _manager.dict()
_init_flags = _manager.dict()
_update_flags = _manager.dict()
direct_log(
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
)
else:
is_multiprocess = False
_internal_lock = asyncio.Lock()
_storage_lock = asyncio.Lock()
_pipeline_status_lock = asyncio.Lock()
_shared_dicts = {}
_init_flags = {}
_update_flags = {}
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
# Mark as initialized
_initialized = True
async def initialize_pipeline_status():
"""
Initialize pipeline namespace with default values.
This function is called during the FastAPI lifespan for each worker.
"""
pipeline_namespace = await get_namespace_data("pipeline_status")
async with get_internal_lock():
# Check if already initialized by checking for required fields
if "busy" in pipeline_namespace:
return
# Create a shared list object for history_messages
history_messages = _manager.list() if is_multiprocess else []
pipeline_namespace.update(
{
"busy": False, # Control concurrent processes
"job_name": "Default Job", # Current job name (indexing files/indexing texts)
"job_start": None, # Job start time
"docs": 0, # Total number of documents to be indexed
"batchs": 0, # Number of batches for processing documents
"cur_batch": 0, # Current processing batch
"request_pending": False, # Flag for pending request for processing
"latest_message": "", # Latest message from pipeline processing
"history_messages": history_messages, # 使用共享列表对象
}
)
direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
async def get_update_flag(namespace: str):
"""
Create a namespace's update flag for a worker.
Return the update flag to the caller for referencing or resetting.
"""
global _update_flags
if _update_flags is None:
raise ValueError("Try to create namespace before Shared-Data is initialized")
async with get_internal_lock():
if namespace not in _update_flags:
if is_multiprocess and _manager is not None:
_update_flags[namespace] = _manager.list()
else:
_update_flags[namespace] = []
direct_log(
f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
)
if is_multiprocess and _manager is not None:
new_update_flag = _manager.Value("b", False)
else:
new_update_flag = False
_update_flags[namespace].append(new_update_flag)
return new_update_flag
async def set_all_update_flags(namespace: str):
"""Set all update flag of namespace indicating all workers need to reload data from files"""
global _update_flags
if _update_flags is None:
raise ValueError("Try to create namespace before Shared-Data is initialized")
async with get_internal_lock():
if namespace not in _update_flags:
raise ValueError(f"Namespace {namespace} not found in update flags")
# Update flags for both modes
for i in range(len(_update_flags[namespace])):
if is_multiprocess:
_update_flags[namespace][i].value = True
else:
_update_flags[namespace][i] = True
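Together with get_update_flag, this implements the reload handshake that the storage classes above rely on: each worker registers its own flag during initialize(), a successful save flips every worker's flag, and each worker resets only its own flag after saving or reloading. A condensed sketch of that protocol; the SharedFileStorage class and its save body are hypothetical, only the shared_storage helpers are real:

```python
from lightrag.kg.shared_storage import (
    get_storage_lock,
    get_update_flag,
    set_all_update_flags,
    is_multiprocess,
)

class SharedFileStorage:
    """Hypothetical storage illustrating the update-flag handshake."""

    def __init__(self, namespace: str):
        self.namespace = namespace

    async def initialize(self):
        # Register this worker's flag; other workers flip it when they save.
        self.storage_updated = await get_update_flag(self.namespace)
        self._storage_lock = get_storage_lock()

    def needs_reload(self) -> bool:
        return self.storage_updated.value if is_multiprocess else bool(self.storage_updated)

    async def save(self):
        async with self._storage_lock:
            # ... write data to disk here ...
            await set_all_update_flags(self.namespace)  # tell every worker to reload
            if is_multiprocess:
                self.storage_updated.value = False      # but skip reloading our own write
            else:
                self.storage_updated = False
```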
async def get_all_update_flags_status() -> Dict[str, list]:
"""
Get update flags status for all namespaces.
Returns:
Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
"""
if _update_flags is None:
return {}
result = {}
async with get_internal_lock():
for namespace, flags in _update_flags.items():
worker_statuses = []
for flag in flags:
if is_multiprocess:
worker_statuses.append(flag.value)
else:
worker_statuses.append(flag)
result[namespace] = worker_statuses
return result
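The per-worker flag map is convenient for diagnostics; a sketch of exposing it through a debug route (the endpoint path and FastAPI wiring are assumptions, not part of this PR):

```python
from fastapi import FastAPI
from lightrag.kg.shared_storage import get_all_update_flags_status

app = FastAPI()

@app.get("/debug/update-flags")
async def debug_update_flags():
    # Maps each namespace to a list of per-worker "needs reload" booleans.
    return await get_all_update_flags_status()
```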
def try_initialize_namespace(namespace: str) -> bool:
"""
Returns True if the current worker (process) gets initialization permission for loading data later.
A worker that does not get permission must not load data from files.
"""
global _init_flags, _manager
if _init_flags is None:
raise ValueError("Try to create nanmespace before Shared-Data is initialized")
if namespace not in _init_flags:
_init_flags[namespace] = True
direct_log(
f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
)
return True
direct_log(
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
)
return False
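A sketch of how a storage backend can use this permission check so that only the first worker loads on-disk data while the others reuse the shared copy; load_namespace and load_from_disk are hypothetical names:

```python
from lightrag.kg.shared_storage import (
    get_namespace_data,
    try_initialize_namespace,
)

async def load_namespace(namespace: str, load_from_disk):
    # Only the worker that wins the permission check reads the file; everyone
    # else sees the data it places into the shared dictionary.
    need_init = try_initialize_namespace(namespace)
    data = await get_namespace_data(namespace)
    if need_init:
        data.update(load_from_disk())
    return data
```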
async def get_namespace_data(namespace: str) -> Dict[str, Any]:
"""get the shared data reference for specific namespace"""
if _shared_dicts is None:
direct_log(
f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
level="ERROR",
)
raise ValueError("Shared dictionaries not initialized")
async with get_internal_lock():
if namespace not in _shared_dicts:
if is_multiprocess and _manager is not None:
_shared_dicts[namespace] = _manager.dict()
else:
_shared_dicts[namespace] = {}
return _shared_dicts[namespace]
def finalize_share_data():
"""
Release shared resources and clean up.
This function should be called when the application is shutting down
to properly release shared resources and avoid memory leaks.
In multi-process mode, it shuts down the Manager and releases all shared objects.
In single-process mode, it simply resets the global variables.
"""
global \
_manager, \
is_multiprocess, \
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
_shared_dicts, \
_init_flags, \
_initialized, \
_update_flags
# Check if already initialized
if not _initialized:
direct_log(
f"Process {os.getpid()} storage data not initialized, nothing to finalize"
)
return
direct_log(
f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
)
# In multi-process mode, shut down the Manager
if is_multiprocess and _manager is not None:
try:
# Clear shared resources before shutting down Manager
if _shared_dicts is not None:
# Clear pipeline status history messages first if exists
try:
pipeline_status = _shared_dicts.get("pipeline_status", {})
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].clear()
except Exception:
pass # Ignore any errors during history messages cleanup
_shared_dicts.clear()
if _init_flags is not None:
_init_flags.clear()
if _update_flags is not None:
# Clear each namespace's update flags list and Value objects
try:
for namespace in _update_flags:
flags_list = _update_flags[namespace]
if isinstance(flags_list, list):
# Clear Value objects in the list
for flag in flags_list:
if hasattr(
flag, "value"
): # Check if it's a Value object
flag.value = False
flags_list.clear()
except Exception:
pass # Ignore any errors during update flags cleanup
_update_flags.clear()
# Shut down the Manager - this will automatically clean up all shared resources
_manager.shutdown()
direct_log(f"Process {os.getpid()} Manager shutdown complete")
except Exception as e:
direct_log(
f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR"
)
# Reset global variables
_manager = None
_initialized = None
is_multiprocess = None
_shared_dicts = None
_init_flags = None
_storage_lock = None
_internal_lock = None
_pipeline_status_lock = None
_update_flags = None
direct_log(f"Process {os.getpid()} storage data finalization complete")

View File

@@ -45,7 +45,6 @@ from .utils import (
lazy_external_import,
limit_async_func_call,
logger,
set_logger,
)
from .types import KnowledgeGraph
from dotenv import load_dotenv
@@ -268,9 +267,14 @@ class LightRAG:
def __post_init__(self):
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
set_logger(self.log_file_path, self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
from lightrag.kg.shared_storage import (
initialize_share_data,
)
initialize_share_data()
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
@@ -692,117 +696,221 @@ class LightRAG:
3. Process each chunk for entity and relation extraction
4. Update the document status
"""
# 1. Get all pending, failed, and abnormally terminated processing documents.
# Run the asynchronous status retrievals in parallel using asyncio.gather
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
)
to_process_docs: dict[str, DocProcessingStatus] = {}
to_process_docs.update(processing_docs)
to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs)
# Get pipeline status shared data and lock
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status_lock = get_pipeline_status_lock()
if not to_process_docs:
logger.info("All documents have been processed or are duplicates")
return
# Check if another process is already processing the queue
async with pipeline_status_lock:
# Ensure only one worker is processing documents
if not pipeline_status.get("busy", False):
# First check whether there are any documents to process
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
# 2. split docs into chunks, insert chunks, update doc status
docs_batches = [
list(to_process_docs.items())[i : i + self.max_parallel_insert]
for i in range(0, len(to_process_docs), self.max_parallel_insert)
]
to_process_docs: dict[str, DocProcessingStatus] = {}
to_process_docs.update(processing_docs)
to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs)
logger.info(f"Number of batches to process: {len(docs_batches)}.")
# If there are no documents to process, return immediately and leave pipeline_status unchanged
if not to_process_docs:
logger.info("No documents to process")
return
batches: list[Any] = []
# 3. iterate over batches
for batch_idx, docs_batch in enumerate(docs_batches):
async def batch(
batch_idx: int,
docs_batch: list[tuple[str, DocProcessingStatus]],
size_batch: int,
) -> None:
logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
# 4. iterate over batch
for doc_id_processing_status in docs_batch:
doc_id, status_doc = doc_id_processing_status
# Generate chunks from document
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
# Documents need processing, update pipeline_status
pipeline_status.update(
{
"busy": True,
"job_name": "indexing files",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "",
}
# Process document (text chunks and full docs) in parallel
tasks = [
self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
}
}
),
self.chunks_vdb.upsert(chunks),
self._process_entity_relation_graph(chunks),
self.full_docs.upsert(
{doc_id: {"content": status_doc.content}}
),
self.text_chunks.upsert(chunks),
]
try:
await asyncio.gather(*tasks)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}")
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
continue
logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
)
# Clear history_messages without replacing the shared list object
del pipeline_status["history_messages"][:]
else:
# Another process is busy, just set request flag and return
pipeline_status["request_pending"] = True
logger.info(
"Another process is already processing the document queue. Request queued."
)
return
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
try:
# Process documents until no more documents or requests
while True:
if not to_process_docs:
log_message = "All documents have been processed or are duplicates"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
break
await asyncio.gather(*batches)
await self._insert_done()
# 2. split docs into chunks, insert chunks, update doc status
docs_batches = [
list(to_process_docs.items())[i : i + self.max_parallel_insert]
for i in range(0, len(to_process_docs), self.max_parallel_insert)
]
log_message = f"Number of batches to process: {len(docs_batches)}."
logger.info(log_message)
# Update pipeline status with current batch information
pipeline_status["docs"] += len(to_process_docs)
pipeline_status["batchs"] += len(docs_batches)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
batches: list[Any] = []
# 3. iterate over batches
for batch_idx, docs_batch in enumerate(docs_batches):
# Update current batch in pipeline status (directly, as it's atomic)
pipeline_status["cur_batch"] += 1
async def batch(
batch_idx: int,
docs_batch: list[tuple[str, DocProcessingStatus]],
size_batch: int,
) -> None:
log_message = (
f"Start processing batch {batch_idx + 1} of {size_batch}."
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# 4. iterate over batch
for doc_id_processing_status in docs_batch:
doc_id, status_doc = doc_id_processing_status
# Generate chunks from document
chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
}
# Process document (text chunks and full docs) in parallel
tasks = [
self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSING,
"updated_at": datetime.now().isoformat(),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
}
}
),
self.chunks_vdb.upsert(chunks),
self._process_entity_relation_graph(chunks),
self.full_docs.upsert(
{doc_id: {"content": status_doc.content}}
),
self.text_chunks.upsert(chunks),
]
try:
await asyncio.gather(*tasks)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
except Exception as e:
logger.error(
f"Failed to process document {doc_id}: {str(e)}"
)
await self.doc_status.upsert(
{
doc_id: {
"status": DocStatus.FAILED,
"error": str(e),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
}
)
continue
log_message = (
f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
await asyncio.gather(*batches)
await self._insert_done()
# Check if there's a pending request to process more documents (with lock)
has_pending_request = False
async with pipeline_status_lock:
has_pending_request = pipeline_status.get("request_pending", False)
if has_pending_request:
# Clear the request flag before checking for more documents
pipeline_status["request_pending"] = False
if not has_pending_request:
break
log_message = "Processing additional documents due to pending request"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Fetch newly pending documents
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
to_process_docs = {}
to_process_docs.update(processing_docs)
to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs)
finally:
log_message = "Document processing pipeline completed"
logger.info(log_message)
# Always reset busy status when done or if an exception occurs (with lock)
async with pipeline_status_lock:
pipeline_status["busy"] = False
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
try:
@@ -833,7 +941,16 @@ class LightRAG:
if storage_inst is not None
]
await asyncio.gather(*tasks)
logger.info("All Insert done")
log_message = "All Insert done"
logger.info(log_message)
# Get pipeline_status and update latest_message and history_messages
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
loop = always_get_an_event_loop()

View File

@@ -339,6 +339,9 @@ async def extract_entities(
global_config: dict[str, str],
llm_response_cache: BaseKVStorage | None = None,
) -> None:
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[
@@ -499,9 +502,10 @@ async def extract_entities(
processed_chunks += 1
entities_count = len(maybe_nodes)
relations_count = len(maybe_edges)
logger.info(
f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
)
log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return dict(maybe_nodes), dict(maybe_edges)
tasks = [_process_single_content(c) for c in ordered_chunks]
@@ -530,17 +534,27 @@ async def extract_entities(
)
if not (all_entities_data or all_relationships_data):
logger.info("Didn't extract any entities and relationships.")
log_message = "Didn't extract any entities and relationships."
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return
if not all_entities_data:
logger.info("Didn't extract any entities")
log_message = "Didn't extract any entities"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if not all_relationships_data:
logger.info("Didn't extract any relationships")
log_message = "Didn't extract any relationships"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
logger.info(
f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
)
log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
verbose_debug(
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
)

View File

@@ -56,6 +56,18 @@ def set_verbose_debug(enabled: bool):
VERBOSE_DEBUG = enabled
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
# Initialize logger
logger = logging.getLogger("lightrag")
logger.propagate = False  # prevent log messages from being sent to the root logger
# Let the main application configure the handlers
logger.setLevel(logging.INFO)
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
class UnlimitedSemaphore:
"""A context manager that allows unlimited access."""
@@ -68,34 +80,6 @@ class UnlimitedSemaphore:
ENCODER = None
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
logger = logging.getLogger("lightrag")
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str, level: int = logging.DEBUG):
"""Set up file logging with the specified level.
Args:
log_file: Path to the log file
level: Logging level (e.g. logging.DEBUG, logging.INFO)
"""
logger.setLevel(level)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(level)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
if not logger.handlers:
logger.addHandler(file_handler)
@dataclass
class EmbeddingFunc:

run_with_gunicorn.py: new executable file (203 lines added)
View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python
"""
Start LightRAG server with Gunicorn
"""
import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"gunicorn",
"tiktoken",
"psutil",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
# Signal handler for graceful shutdown
def signal_handler(sig, frame):
print("\n\n" + "=" * 80)
print("RECEIVED TERMINATION SIGNAL")
print(f"Process ID: {os.getpid()}")
print("=" * 80 + "\n")
# Release shared resources
finalize_share_data()
# Exit with success status
sys.exit(0)
def main():
# Check and install dependencies
check_and_install_dependencies()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information
display_splash_screen(args)
print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}")
print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication
from gunicorn.app.base import BaseApplication
# Define a custom application class that loads our config
class GunicornApp(BaseApplication):
def __init__(self, app, options=None):
self.options = options or {}
self.application = app
super().__init__()
def load_config(self):
# Define valid Gunicorn configuration options
valid_options = {
"bind",
"workers",
"worker_class",
"timeout",
"keepalive",
"preload_app",
"errorlog",
"accesslog",
"loglevel",
"certfile",
"keyfile",
"limit_request_line",
"limit_request_fields",
"limit_request_field_size",
"graceful_timeout",
"max_requests",
"max_requests_jitter",
}
# Special hooks that need to be set separately
special_hooks = {
"on_starting",
"on_reload",
"on_exit",
"pre_fork",
"post_fork",
"pre_exec",
"pre_request",
"post_request",
"worker_init",
"worker_exit",
"nworkers_changed",
"child_exit",
}
# Import and configure the gunicorn_config module
import gunicorn_config
# Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1))
)
# Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = (
args.log_level.lower()
if args.log_level
else os.getenv("LOG_LEVEL", "info")
)
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
)
# Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in (
"true",
"1",
"yes",
"t",
"on",
):
gunicorn_config.certfile = (
args.ssl_certfile
if args.ssl_certfile
else os.getenv("SSL_CERTFILE")
)
gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
)
# Set configuration options from the module
for key in dir(gunicorn_config):
if key in valid_options:
value = getattr(gunicorn_config, key)
# Skip functions like on_starting and None values
if not callable(value) and value is not None:
self.cfg.set(key, value)
# Set special hooks
elif key in special_hooks:
value = getattr(gunicorn_config, key)
if callable(value):
self.cfg.set(key, value)
if hasattr(gunicorn_config, "logconfig_dict"):
self.cfg.set(
"logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
)
def load(self):
# Import the application
from lightrag.api.lightrag_server import get_application
return get_application(args)
# Create the application
app = GunicornApp("")
# Convert workers to an integer; multi-process mode is used when it is greater than 1
workers_count = int(args.workers)
if workers_count > 1:
# Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
initialize_share_data(workers_count)
else:
initialize_share_data(1)
# Run the application
print("\nStarting Gunicorn with direct Python API...")
app.run()
if __name__ == "__main__":
main()
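load_config() imports a gunicorn_config module from the import path; the PR ships its own gunicorn_config.py, which is not shown in this diff. The sketch below only illustrates the kind of settings such a module can expose, using option names from the valid_options set above; the defaults are illustrative:

```python
# gunicorn_config.py -- illustrative sketch only; the real file in the PR may differ.
import os

workers = int(os.getenv("WORKERS", 2))
bind = f"{os.getenv('HOST', '0.0.0.0')}:{os.getenv('PORT', 9621)}"
loglevel = os.getenv("LOG_LEVEL", "info")
timeout = int(os.getenv("TIMEOUT", 150))
keepalive = int(os.getenv("KEEPALIVE", 5))
preload_app = True  # shared data initialized in the master is inherited by workers

def on_starting(server):
    # Gunicorn hook: runs once in the master process before workers are forked.
    print(f"Gunicorn master {os.getpid()} starting with {workers} workers")
```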

View File

@@ -112,6 +112,7 @@ setuptools.setup(
entry_points={
"console_scripts": [
"lightrag-server=lightrag.api.lightrag_server:main [api]",
"lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]",
"lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
],
},