Merge pull request #969 from danielaskdd/add-multi-worker-support

Add multi-worker support for the API Server
This commit is contained in:
Yannick Stephan
2025-03-02 17:47:00 +01:00
committed by GitHub
20 changed files with 1927 additions and 455 deletions

1
.gitignore vendored
View File

@@ -21,6 +21,7 @@ site/
# Logs / Reports
*.log
*.log.*
*.logfire
*.coverage/
log/

View File

@@ -1,6 +1,9 @@
### This is a sample .env file
### Server Configuration
# HOST=0.0.0.0
# PORT=9621
# WORKERS=1
# NAMESPACE_PREFIX=lightrag # separating data from different LightRAG instances
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
@@ -22,6 +25,9 @@
### Logging level
# LOG_LEVEL=INFO
# VERBOSE=False
# LOG_DIR=/path/to/log/directory # Log file directory path, defaults to current working directory
# LOG_MAX_BYTES=10485760 # Log file max size in bytes, defaults to 10MB
# LOG_BACKUP_COUNT=5 # Number of backup files to keep, defaults to 5
### Max async calls for LLM
# MAX_ASYNC=4
@@ -138,3 +144,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
### Qdrant
QDRANT_URL=http://localhost:16333
# QDRANT_API_KEY=your-api-key
### Redis
REDIS_URI=redis://localhost:6379

View File

@@ -24,6 +24,8 @@ pip install -e ".[api]"
### Starting API Server with Default Settings
After installing LightRAG with API support, you can start LightRAG with the command: `lightrag-server`
LightRAG requires both an LLM and an Embedding Model working together to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends:
* ollama
@@ -92,10 +94,43 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama
LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai
# start with ollama llm and ollama embedding (no apikey is needed)
light-server --llm-binding ollama --embedding-binding ollama
```
### Starting API Server with Gunicorn (Production)
For production deployments, it's recommended to run the server under Gunicorn, which manages multiple Uvicorn worker processes to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionality.
```bash
# Start with lightrag-gunicorn command
lightrag-gunicorn --workers 4
# Alternatively, you can use the module directly
python -m lightrag.api.run_with_gunicorn --workers 4
```
The `--workers` parameter is crucial for performance (a sizing sketch follows this list):
- Determines how many worker processes Gunicorn will spawn to handle requests
- Each worker can handle concurrent requests using asyncio
- Recommended value is (2 x number_of_cores) + 1
- For example, on a 4-core machine, use 9 workers: (2 x 4) + 1 = 9
- Consider your server's memory when setting this value, as each worker consumes memory
Other important startup parameters (combined into an example command after this list):
- `--host`: Server listening address (default: 0.0.0.0)
- `--port`: Server listening port (default: 9621)
- `--timeout`: Request handling timeout (default: 150 seconds)
- `--log-level`: Logging level (default: INFO)
- `--ssl`: Enable HTTPS
- `--ssl-certfile`: Path to SSL certificate file
- `--ssl-keyfile`: Path to SSL private key file
The command line parameters and environment variables accepted by `run_with_gunicorn.py` are exactly the same as those of `light-server`.
### For Azure OpenAI Backend
Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
```bash
# Change the resource group name, location and OpenAI resource name as needed

View File

@@ -0,0 +1,187 @@
# gunicorn_config.py
import os
import logging
from lightrag.kg.shared_storage import finalize_share_data
from lightrag.api.lightrag_server import LightragPathFilter
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
# These variables will be set by run_with_gunicorn.py
workers = None
bind = None
loglevel = None
certfile = None
keyfile = None
# Enable preload_app option
preload_app = True
# Use Uvicorn worker
worker_class = "uvicorn.workers.UvicornWorker"
# Other Gunicorn configurations
timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py
keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s
# Logging configuration
errorlog = os.getenv("ERROR_LOG", log_file_path) # Default write to lightrag.log
accesslog = os.getenv("ACCESS_LOG", log_file_path) # Default write to lightrag.log
logconfig_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "standard",
"stream": "ext://sys.stdout",
},
"file": {
"class": "logging.handlers.RotatingFileHandler",
"formatter": "standard",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf8",
},
},
"filters": {
"path_filter": {
"()": "lightrag.api.lightrag_server.LightragPathFilter",
},
},
"loggers": {
"lightrag": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
},
"gunicorn": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
},
"gunicorn.error": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
},
"gunicorn.access": {
"handlers": ["console", "file"],
"level": loglevel.upper() if loglevel else "INFO",
"propagate": False,
"filters": ["path_filter"],
},
},
}
def on_starting(server):
"""
Executed when Gunicorn starts, before forking the first worker processes
You can use this function to do more initialization tasks for all processes
"""
print("=" * 80)
print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)")
print(f"Process ID: {os.getpid()}")
print("=" * 80)
# Memory usage monitoring
try:
import psutil
process = psutil.Process(os.getpid())
memory_info = process.memory_info()
msg = (
f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB"
)
print(msg)
except ImportError:
print("psutil not installed, skipping memory usage reporting")
print("Gunicorn initialization complete, forking workers...\n")
def on_exit(server):
"""
Executed when Gunicorn is shutting down.
This is a good place to release shared resources.
"""
print("=" * 80)
print("GUNICORN MASTER PROCESS: Shutting down")
print(f"Process ID: {os.getpid()}")
print("=" * 80)
# Release shared resources
finalize_share_data()
print("=" * 80)
print("Gunicorn shutdown complete")
print("=" * 80)
def post_fork(server, worker):
"""
Executed after a worker has been forked.
This is a good place to set up worker-specific configurations.
"""
# Configure formatters
detailed_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False):
"""Set up a logger with console and file handlers"""
logger_instance = logging.getLogger(logger_name)
logger_instance.setLevel(level)
logger_instance.handlers = [] # Clear existing handlers
logger_instance.propagate = False
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(simple_formatter)
console_handler.setLevel(level)
logger_instance.addHandler(console_handler)
# Add file handler
file_handler = logging.handlers.RotatingFileHandler(
filename=log_file_path,
maxBytes=log_max_bytes,
backupCount=log_backup_count,
encoding="utf-8",
)
file_handler.setFormatter(detailed_formatter)
file_handler.setLevel(level)
logger_instance.addHandler(file_handler)
# Add path filter if requested
if add_filter:
path_filter = LightragPathFilter()
logger_instance.addFilter(path_filter)
# Set up main loggers
log_level = loglevel.upper() if loglevel else "INFO"
setup_logger("uvicorn", log_level)
setup_logger("uvicorn.access", log_level, add_filter=True)
setup_logger("lightrag", log_level, add_filter=True)
# Set up lightrag submodule loggers
for name in logging.root.manager.loggerDict:
if name.startswith("lightrag."):
setup_logger(name, log_level, add_filter=True)
# Disable uvicorn.error logger
uvicorn_error_logger = logging.getLogger("uvicorn.error")
uvicorn_error_logger.handlers = []
uvicorn_error_logger.setLevel(logging.CRITICAL)
uvicorn_error_logger.propagate = False

View File

@@ -8,11 +8,12 @@ from fastapi import (
)
from fastapi.responses import FileResponse
import asyncio
import threading
import os
from fastapi.staticfiles import StaticFiles
import logging
from typing import Dict
import logging.config
import uvicorn
import pipmaster as pm
from fastapi.staticfiles import StaticFiles
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
@@ -29,7 +30,6 @@ from lightrag import LightRAG
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.utils import logger
from .routers.document_routes import (
DocumentManager,
create_document_routes,
@@ -39,33 +39,25 @@ from .routers.query_routes import create_query_routes
from .routers.graph_routes import create_graph_routes
from .routers.ollama_api import OllamaAPI
from lightrag.utils import logger, set_verbose_debug
from lightrag.kg.shared_storage import (
get_namespace_data,
get_pipeline_status_lock,
initialize_pipeline_status,
get_all_update_flags_status,
)
# Load environment variables
try:
load_dotenv(override=True)
except Exception as e:
logger.warning(f"Failed to load .env file: {e}")
# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")
# Global configuration
global_top_k = 60 # default value
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = threading.Lock()
class AccessLogFilter(logging.Filter):
class LightragPathFilter(logging.Filter):
"""Filter for lightrag logger to filter out frequent path access logs"""
def __init__(self):
super().__init__()
# Define paths to be filtered
@@ -73,17 +65,18 @@ class AccessLogFilter(logging.Filter):
def filter(self, record):
try:
# Check if record has the required attributes for an access log
if not hasattr(record, "args") or not isinstance(record.args, tuple):
return True
if len(record.args) < 5:
return True
# Extract method, path and status from the record args
method = record.args[1]
path = record.args[2]
status = record.args[4]
# print(f"Debug - Method: {method}, Path: {path}, Status: {status}")
# print(f"Debug - Filtered paths: {self.filtered_paths}")
# Filter out successful GET requests to filtered paths
if (
method == "GET"
and (status == 200 or status == 304)
@@ -92,19 +85,14 @@ class AccessLogFilter(logging.Filter):
return False
return True
except Exception:
# In case of any error, let the message through
return True
def create_app(args):
# Set global top_k
global global_top_k
global_top_k = args.top_k # save top_k from args
# Initialize verbose debug setting
from lightrag.utils import set_verbose_debug
# Setup logging
logger.setLevel(args.log_level)
set_verbose_debug(args.verbose)
# Verify that bindings are correctly setup
@@ -138,11 +126,6 @@ def create_app(args):
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
@@ -158,28 +141,23 @@ def create_app(args):
try:
# Initialize database connections
await rag.initialize_storages()
await initialize_pipeline_status()
# Auto scan documents if enabled # Auto scan documents if enabled
if args.auto_scan_at_startup: if args.auto_scan_at_startup:
# Start scanning in background # Check if a task is already running (with lock protection)
with progress_lock: pipeline_status = await get_namespace_data("pipeline_status")
if not scan_progress["is_scanning"]: should_start_task = False
scan_progress["is_scanning"] = True async with get_pipeline_status_lock():
scan_progress["indexed_count"] = 0 if not pipeline_status.get("busy", False):
scan_progress["progress"] = 0 should_start_task = True
# Only start the task if no other task is running
if should_start_task:
# Create background task # Create background task
task = asyncio.create_task( task = asyncio.create_task(run_scanning_process(rag, doc_manager))
run_scanning_process(rag, doc_manager)
)
app.state.background_tasks.add(task) app.state.background_tasks.add(task)
task.add_done_callback(app.state.background_tasks.discard) task.add_done_callback(app.state.background_tasks.discard)
ASCIIColors.info( logger.info("Auto scan task started at startup.")
f"Started background scanning of documents from {args.input_dir}"
)
else:
ASCIIColors.info(
"Skip document scanning(another scanning is active)"
)
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n") ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@@ -398,6 +376,9 @@ def create_app(args):
@app.get("/health", dependencies=[Depends(optional_api_key)]) @app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status(): async def get_status():
"""Get current system status""" """Get current system status"""
# Get update flags status for all namespaces
update_status = await get_all_update_flags_status()
return { return {
"status": "healthy", "status": "healthy",
"working_directory": str(args.working_dir), "working_directory": str(args.working_dir),
@@ -417,6 +398,7 @@ def create_app(args):
"graph_storage": args.graph_storage, "graph_storage": args.graph_storage,
"vector_storage": args.vector_storage, "vector_storage": args.vector_storage,
}, },
"update_status": update_status,
} }
# Webui mount webui/index.html # Webui mount webui/index.html
@@ -435,12 +417,30 @@ def create_app(args):
return app return app
def main(): def get_application(args=None):
"""Factory function for creating the FastAPI application"""
if args is None:
args = parse_args() args = parse_args()
import uvicorn return create_app(args)
import logging.config
def configure_logging():
"""Configure logging for uvicorn startup"""
# Reset any existing handlers to ensure clean configuration
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
logger = logging.getLogger(logger_name)
logger.handlers = []
logger.filters = []
# Get log directory path from environment variable
log_dir = os.getenv("LOG_DIR", os.getcwd())
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
# Get log file max size and backup count from environment variables
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
# Configure uvicorn logging
logging.config.dictConfig( logging.config.dictConfig(
{ {
"version": 1, "version": 1,
@@ -449,36 +449,106 @@ def main():
"default": { "default": {
"format": "%(levelname)s: %(message)s", "format": "%(levelname)s: %(message)s",
}, },
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
}, },
"handlers": { "handlers": {
"default": { "console": {
"formatter": "default", "formatter": "default",
"class": "logging.StreamHandler", "class": "logging.StreamHandler",
"stream": "ext://sys.stderr", "stream": "ext://sys.stderr",
}, },
"file": {
"formatter": "detailed",
"class": "logging.handlers.RotatingFileHandler",
"filename": log_file_path,
"maxBytes": log_max_bytes,
"backupCount": log_backup_count,
"encoding": "utf-8",
},
}, },
"loggers": { "loggers": {
"uvicorn.access": { # Configure all uvicorn related loggers
"handlers": ["default"], "uvicorn": {
"handlers": ["console", "file"],
"level": "INFO", "level": "INFO",
"propagate": False, "propagate": False,
}, },
"uvicorn.access": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
"uvicorn.error": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
},
"lightrag": {
"handlers": ["console", "file"],
"level": "INFO",
"propagate": False,
"filters": ["path_filter"],
},
},
"filters": {
"path_filter": {
"()": "lightrag.api.lightrag_server.LightragPathFilter",
},
}, },
} }
) )
# Add filter to uvicorn access logger
uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_access_logger.addFilter(AccessLogFilter())
app = create_app(args) def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"uvicorn",
"tiktoken",
"fastapi",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
def main():
# Check if running under Gunicorn
if "GUNICORN_CMD_ARGS" in os.environ:
# If started with Gunicorn, return directly as Gunicorn will call get_application
print("Running under Gunicorn - worker management handled by Gunicorn")
return
# Check and install dependencies
check_and_install_dependencies()
from multiprocessing import freeze_support
freeze_support()
# Configure logging before parsing args
configure_logging()
args = parse_args(is_uvicorn_mode=True)
display_splash_screen(args) display_splash_screen(args)
# Create application instance directly instead of using factory function
app = create_app(args)
# Start Uvicorn in single process mode
uvicorn_config = { uvicorn_config = {
"app": app, "app": app, # Pass application instance directly instead of string path
"host": args.host, "host": args.host,
"port": args.port, "port": args.port,
"log_config": None, # Disable default config "log_config": None, # Disable default config
} }
if args.ssl: if args.ssl:
uvicorn_config.update( uvicorn_config.update(
{ {
@@ -486,6 +556,8 @@ def main():
"ssl_keyfile": args.ssl_keyfile, "ssl_keyfile": args.ssl_keyfile,
} }
) )
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
uvicorn.run(**uvicorn_config) uvicorn.run(**uvicorn_config)

View File

@@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API.
""" """
import asyncio import asyncio
import logging from lightrag.utils import logger
import os
import aiofiles import aiofiles
import shutil import shutil
import traceback import traceback
@@ -12,7 +11,6 @@ import pipmaster as pm
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@@ -23,18 +21,6 @@ from ..utils_api import get_api_key_dependency
router = APIRouter(prefix="/documents", tags=["documents"]) router = APIRouter(prefix="/documents", tags=["documents"])
# Global progress tracker
scan_progress: Dict = {
"is_scanning": False,
"current_file": "",
"indexed_count": 0,
"total_files": 0,
"progress": 0,
}
# Lock for thread-safe operations
progress_lock = asyncio.Lock()
# Temporary file prefix # Temporary file prefix
temp_prefix = "__tmp__" temp_prefix = "__tmp__"
@@ -161,19 +147,12 @@ class DocumentManager:
"""Scan input directory for new files""" """Scan input directory for new files"""
new_files = [] new_files = []
for ext in self.supported_extensions: for ext in self.supported_extensions:
logging.debug(f"Scanning for {ext} files in {self.input_dir}") logger.debug(f"Scanning for {ext} files in {self.input_dir}")
for file_path in self.input_dir.rglob(f"*{ext}"): for file_path in self.input_dir.rglob(f"*{ext}"):
if file_path not in self.indexed_files: if file_path not in self.indexed_files:
new_files.append(file_path) new_files.append(file_path)
return new_files return new_files
# def scan_directory(self) -> List[Path]:
# new_files = []
# for ext in self.supported_extensions:
# for file_path in self.input_dir.rglob(f"*{ext}"):
# new_files.append(file_path)
# return new_files
def mark_as_indexed(self, file_path: Path): def mark_as_indexed(self, file_path: Path):
self.indexed_files.add(file_path) self.indexed_files.add(file_path)
@@ -287,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
) )
content += "\n" content += "\n"
case _: case _:
logging.error( logger.error(
f"Unsupported file type: {file_path.name} (extension {ext})" f"Unsupported file type: {file_path.name} (extension {ext})"
) )
return False return False
@@ -295,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
# Insert into the RAG queue # Insert into the RAG queue
if content: if content:
await rag.apipeline_enqueue_documents(content) await rag.apipeline_enqueue_documents(content)
logging.info(f"Successfully fetched and enqueued file: {file_path.name}") logger.info(f"Successfully fetched and enqueued file: {file_path.name}")
return True return True
else: else:
logging.error(f"No content could be extracted from file: {file_path.name}") logger.error(f"No content could be extracted from file: {file_path.name}")
except Exception as e: except Exception as e:
logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
finally: finally:
if file_path.name.startswith(temp_prefix): if file_path.name.startswith(temp_prefix):
try: try:
file_path.unlink() file_path.unlink()
except Exception as e: except Exception as e:
logging.error(f"Error deleting file {file_path}: {str(e)}") logger.error(f"Error deleting file {file_path}: {str(e)}")
return False return False
@@ -324,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
await rag.apipeline_process_enqueue_documents() await rag.apipeline_process_enqueue_documents()
except Exception as e: except Exception as e:
logging.error(f"Error indexing file {file_path.name}: {str(e)}") logger.error(f"Error indexing file {file_path.name}: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
@@ -349,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
if enqueued: if enqueued:
await rag.apipeline_process_enqueue_documents() await rag.apipeline_process_enqueue_documents()
except Exception as e: except Exception as e:
logging.error(f"Error indexing files: {str(e)}") logger.error(f"Error indexing files: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def pipeline_index_texts(rag: LightRAG, texts: List[str]): async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
@@ -393,30 +372,17 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
"""Background task to scan and index documents""" """Background task to scan and index documents"""
try: try:
new_files = doc_manager.scan_directory_for_new_files() new_files = doc_manager.scan_directory_for_new_files()
scan_progress["total_files"] = len(new_files) total_files = len(new_files)
logger.info(f"Found {total_files} new files to index.")
logging.info(f"Found {len(new_files)} new files to index.") for idx, file_path in enumerate(new_files):
for file_path in new_files:
try: try:
async with progress_lock:
scan_progress["current_file"] = os.path.basename(file_path)
await pipeline_index_file(rag, file_path) await pipeline_index_file(rag, file_path)
except Exception as e:
async with progress_lock: logger.error(f"Error indexing file {file_path}: {str(e)}")
scan_progress["indexed_count"] += 1
scan_progress["progress"] = (
scan_progress["indexed_count"] / scan_progress["total_files"]
) * 100
except Exception as e: except Exception as e:
logging.error(f"Error indexing file {file_path}: {str(e)}") logger.error(f"Error during scanning process: {str(e)}")
except Exception as e:
logging.error(f"Error during scanning process: {str(e)}")
finally:
async with progress_lock:
scan_progress["is_scanning"] = False
def create_document_routes( def create_document_routes(
@@ -436,34 +402,10 @@ def create_document_routes(
Returns: Returns:
dict: A dictionary containing the scanning status dict: A dictionary containing the scanning status
""" """
async with progress_lock:
if scan_progress["is_scanning"]:
return {"status": "already_scanning"}
scan_progress["is_scanning"] = True
scan_progress["indexed_count"] = 0
scan_progress["progress"] = 0
# Start the scanning process in the background # Start the scanning process in the background
background_tasks.add_task(run_scanning_process, rag, doc_manager) background_tasks.add_task(run_scanning_process, rag, doc_manager)
return {"status": "scanning_started"} return {"status": "scanning_started"}
@router.get("/scan-progress")
async def get_scan_progress():
"""
Get the current progress of the document scanning process.
Returns:
dict: A dictionary containing the current scanning progress information including:
- is_scanning: Whether a scan is currently in progress
- current_file: The file currently being processed
- indexed_count: Number of files indexed so far
- total_files: Total number of files to process
- progress: Percentage of completion
"""
async with progress_lock:
return scan_progress
@router.post("/upload", dependencies=[Depends(optional_api_key)]) @router.post("/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir( async def upload_to_input_dir(
background_tasks: BackgroundTasks, file: UploadFile = File(...) background_tasks: BackgroundTasks, file: UploadFile = File(...)
@@ -504,8 +446,8 @@ def create_document_routes(
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.", message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
) )
except Exception as e: except Exception as e:
logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
@@ -537,8 +479,8 @@ def create_document_routes(
message="Text successfully received. Processing will continue in background.", message="Text successfully received. Processing will continue in background.",
) )
except Exception as e: except Exception as e:
logging.error(f"Error /documents/text: {str(e)}") logger.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
@@ -572,8 +514,8 @@ def create_document_routes(
message="Text successfully received. Processing will continue in background.", message="Text successfully received. Processing will continue in background.",
) )
except Exception as e: except Exception as e:
logging.error(f"Error /documents/text: {str(e)}") logger.error(f"Error /documents/text: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
@@ -615,8 +557,8 @@ def create_document_routes(
message=f"File '{file.filename}' saved successfully. Processing will continue in background.", message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
) )
except Exception as e: except Exception as e:
logging.error(f"Error /documents/file: {str(e)}") logger.error(f"Error /documents/file: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post( @router.post(
@@ -678,8 +620,8 @@ def create_document_routes(
return InsertResponse(status=status, message=status_message) return InsertResponse(status=status, message=status_message)
except Exception as e: except Exception as e:
logging.error(f"Error /documents/batch: {str(e)}") logger.error(f"Error /documents/batch: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.delete( @router.delete(
@@ -706,8 +648,42 @@ def create_document_routes(
status="success", message="All documents cleared successfully" status="success", message="All documents cleared successfully"
) )
except Exception as e: except Exception as e:
logging.error(f"Error DELETE /documents: {str(e)}") logger.error(f"Error DELETE /documents: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("/pipeline_status", dependencies=[Depends(optional_api_key)])
async def get_pipeline_status():
"""
Get the current status of the document indexing pipeline.
This endpoint returns information about the current state of the document processing pipeline,
including whether it's busy, the current job name, when it started, how many documents
are being processed, how many batches there are, and which batch is currently being processed.
Returns:
dict: A dictionary containing the pipeline status information
"""
try:
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
# Convert to regular dict if it's a Manager.dict
status_dict = dict(pipeline_status)
# Convert history_messages to a regular list if it's a Manager.list
if "history_messages" in status_dict:
status_dict["history_messages"] = list(status_dict["history_messages"])
# Format the job_start time if it exists
if status_dict.get("job_start"):
status_dict["job_start"] = str(status_dict["job_start"])
return status_dict
except Exception as e:
logger.error(f"Error getting pipeline status: {str(e)}")
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("", dependencies=[Depends(optional_api_key)]) @router.get("", dependencies=[Depends(optional_api_key)])
@@ -763,8 +739,8 @@ def create_document_routes(
) )
return response return response
except Exception as e: except Exception as e:
logging.error(f"Error GET /documents: {str(e)}") logger.error(f"Error GET /documents: {str(e)}")
logging.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return router return router

View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python
"""
Start LightRAG server with Gunicorn
"""
import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"gunicorn",
"tiktoken",
"psutil",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
# Signal handler for graceful shutdown
def signal_handler(sig, frame):
print("\n\n" + "=" * 80)
print("RECEIVED TERMINATION SIGNAL")
print(f"Process ID: {os.getpid()}")
print("=" * 80 + "\n")
# Release shared resources
finalize_share_data()
# Exit with success status
sys.exit(0)
def main():
# Check and install dependencies
check_and_install_dependencies()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information
display_splash_screen(args)
print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}")
print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication
from gunicorn.app.base import BaseApplication
# Define a custom application class that loads our config
class GunicornApp(BaseApplication):
def __init__(self, app, options=None):
self.options = options or {}
self.application = app
super().__init__()
def load_config(self):
# Define valid Gunicorn configuration options
valid_options = {
"bind",
"workers",
"worker_class",
"timeout",
"keepalive",
"preload_app",
"errorlog",
"accesslog",
"loglevel",
"certfile",
"keyfile",
"limit_request_line",
"limit_request_fields",
"limit_request_field_size",
"graceful_timeout",
"max_requests",
"max_requests_jitter",
}
# Special hooks that need to be set separately
special_hooks = {
"on_starting",
"on_reload",
"on_exit",
"pre_fork",
"post_fork",
"pre_exec",
"pre_request",
"post_request",
"worker_init",
"worker_exit",
"nworkers_changed",
"child_exit",
}
# Import and configure the gunicorn_config module
from lightrag.api import gunicorn_config
# Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1))
)
# Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = (
args.log_level.lower()
if args.log_level
else os.getenv("LOG_LEVEL", "info")
)
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
)
# Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in (
"true",
"1",
"yes",
"t",
"on",
):
gunicorn_config.certfile = (
args.ssl_certfile
if args.ssl_certfile
else os.getenv("SSL_CERTFILE")
)
gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
)
# Set configuration options from the module
for key in dir(gunicorn_config):
if key in valid_options:
value = getattr(gunicorn_config, key)
# Skip functions like on_starting and None values
if not callable(value) and value is not None:
self.cfg.set(key, value)
# Set special hooks
elif key in special_hooks:
value = getattr(gunicorn_config, key)
if callable(value):
self.cfg.set(key, value)
if hasattr(gunicorn_config, "logconfig_dict"):
self.cfg.set(
"logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
)
def load(self):
# Import the application
from lightrag.api.lightrag_server import get_application
return get_application(args)
# Create the application
app = GunicornApp("")
# Force workers to be an integer and greater than 1 for multi-process mode
workers_count = int(args.workers)
if workers_count > 1:
# Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
initialize_share_data(workers_count)
else:
initialize_share_data(1)
# Run the application
print("\nStarting Gunicorn with direct Python API...")
app.run()
if __name__ == "__main__":
main()

View File

@@ -6,6 +6,7 @@ import os
import argparse
from typing import Optional
import sys
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
@@ -110,10 +111,13 @@ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
return default return default
def parse_args() -> argparse.Namespace: def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
""" """
Parse command line arguments with environment variable fallback Parse command line arguments with environment variable fallback
Args:
is_uvicorn_mode: Whether running under uvicorn mode
Returns: Returns:
argparse.Namespace: Parsed arguments argparse.Namespace: Parsed arguments
""" """
@@ -260,6 +264,14 @@ def parse_args() -> argparse.Namespace:
help="Enable automatic scanning when the program starts", help="Enable automatic scanning when the program starts",
) )
# Server workers configuration
parser.add_argument(
"--workers",
type=int,
default=get_env_value("WORKERS", 1, int),
help="Number of worker processes (default: from env or 1)",
)
# LLM and embedding bindings # LLM and embedding bindings
parser.add_argument( parser.add_argument(
"--llm-binding", "--llm-binding",
@@ -278,6 +290,15 @@ def parse_args() -> argparse.Namespace:
args = parser.parse_args() args = parser.parse_args()
# If in uvicorn mode and workers > 1, force it to 1 and log warning
if is_uvicorn_mode and args.workers > 1:
original_workers = args.workers
args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
)
# convert relative path to absolute path
args.working_dir = os.path.abspath(args.working_dir)
args.input_dir = os.path.abspath(args.input_dir)
@@ -346,17 +367,27 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{args.host}") ASCIIColors.yellow(f"{args.host}")
ASCIIColors.white(" ├─ Port: ", end="") ASCIIColors.white(" ├─ Port: ", end="")
ASCIIColors.yellow(f"{args.port}") ASCIIColors.yellow(f"{args.port}")
ASCIIColors.white(" ├─ Workers: ", end="")
ASCIIColors.yellow(f"{args.workers}")
ASCIIColors.white(" ├─ CORS Origins: ", end="") ASCIIColors.white(" ├─ CORS Origins: ", end="")
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
ASCIIColors.white(" ├─ SSL Enabled: ", end="") ASCIIColors.white(" ├─ SSL Enabled: ", end="")
ASCIIColors.yellow(f"{args.ssl}") ASCIIColors.yellow(f"{args.ssl}")
ASCIIColors.white(" └─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set")
if args.ssl: if args.ssl:
ASCIIColors.white(" ├─ SSL Cert: ", end="") ASCIIColors.white(" ├─ SSL Cert: ", end="")
ASCIIColors.yellow(f"{args.ssl_certfile}") ASCIIColors.yellow(f"{args.ssl_certfile}")
ASCIIColors.white(" ─ SSL Key: ", end="") ASCIIColors.white(" ─ SSL Key: ", end="")
ASCIIColors.yellow(f"{args.ssl_keyfile}") ASCIIColors.yellow(f"{args.ssl_keyfile}")
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
ASCIIColors.white(" ├─ Log Level: ", end="")
ASCIIColors.yellow(f"{args.log_level}")
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" ├─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
ASCIIColors.white(" └─ API Key: ", end="")
ASCIIColors.yellow("Set" if args.key else "Not Set")
# Directory Configuration # Directory Configuration
ASCIIColors.magenta("\n📂 Directory Configuration:") ASCIIColors.magenta("\n📂 Directory Configuration:")
@@ -415,16 +446,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.white(" └─ Document Status Storage: ", end="") ASCIIColors.white(" └─ Document Status Storage: ", end="")
ASCIIColors.yellow(f"{args.doc_status_storage}") ASCIIColors.yellow(f"{args.doc_status_storage}")
ASCIIColors.magenta("\n🛠️ System Configuration:")
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
ASCIIColors.white(" ├─ Log Level: ", end="")
ASCIIColors.yellow(f"{args.log_level}")
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" └─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
# Server Status # Server Status
ASCIIColors.green("\n✨ Server starting up...\n") ASCIIColors.green("\n✨ Server starting up...\n")
@@ -478,7 +499,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.cyan(""" 3. Basic Operations: ASCIIColors.cyan(""" 3. Basic Operations:
- POST /upload_document: Upload new documents to RAG - POST /upload_document: Upload new documents to RAG
- POST /query: Query your document collection - POST /query: Query your document collection
- GET /collections: List available collections
4. Monitor the server: 4. Monitor the server:
- Check server logs for detailed operation information - Check server logs for detailed operation information

View File

@@ -2,25 +2,25 @@ import os
import time
import asyncio
from typing import Any, final
import json
import numpy as np
from dataclasses import dataclass
import pipmaster as pm
from lightrag.utils import logger, compute_mdhash_id
from lightrag.base import BaseVectorStorage
if not pm.is_installed("faiss"):
pm.install("faiss")
import faiss  # type: ignore
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final @final
@@ -55,14 +55,40 @@ class FaissVectorDBStorage(BaseVectorStorage):
# If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP.
self._index = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta = {}
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_index(self):
"""Check if the shtorage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if storage was updated by another process
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
):
logger.info(
f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
)
# Reload data
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._index
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """
Insert or update vectors in the Faiss index. Insert or update vectors in the Faiss index.
@@ -113,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
return [] return []
# Normalize embeddings for cosine similarity (in-place) # Convert to float32 and normalize embeddings for cosine similarity (in-place)
embeddings = embeddings.astype(np.float32)
faiss.normalize_L2(embeddings) faiss.normalize_L2(embeddings)
# Upsert logic: # Upsert logic:
@@ -127,18 +154,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
existing_ids_to_remove.append(faiss_internal_id) existing_ids_to_remove.append(faiss_internal_id)
if existing_ids_to_remove: if existing_ids_to_remove:
self._remove_faiss_ids(existing_ids_to_remove) await self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors # Step 2: Add new vectors
start_idx = self._index.ntotal index = await self._get_index()
self._index.add(embeddings) start_idx = index.ntotal
index.add(embeddings)
# Step 3: Store metadata + vector for each new ID # Step 3: Store metadata + vector for each new ID
for i, meta in enumerate(list_data): for i, meta in enumerate(list_data):
fid = start_idx + i fid = start_idx + i
# Store the raw vector so we can rebuild if something is removed # Store the raw vector so we can rebuild if something is removed
meta["__vector__"] = embeddings[i].tolist() meta["__vector__"] = embeddings[i].tolist()
self._id_to_meta[fid] = meta self._id_to_meta.update({fid: meta})
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
@@ -157,7 +185,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
# Perform the similarity search # Perform the similarity search
distances, indices = self._index.search(embedding, top_k) index = await self._get_index()
distances, indices = index.search(embedding, top_k)
distances = distances[0] distances = distances[0]
indices = indices[0] indices = indices[0]
@@ -201,8 +230,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
to_remove.append(fid) to_remove.append(fid)
if to_remove: if to_remove:
self._remove_faiss_ids(to_remove) await self._remove_faiss_ids(to_remove)
logger.info( logger.debug(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
) )
@@ -223,12 +252,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Found {len(relations)} relations for {entity_name}") logger.debug(f"Found {len(relations)} relations for {entity_name}")
if relations: if relations:
self._remove_faiss_ids(relations) await self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}") logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
async def index_done_callback(self) -> None:
self._save_faiss_index()
# -------------------------------------------------------------------------------- # --------------------------------------------------------------------------------
# Internal helper methods # Internal helper methods
# -------------------------------------------------------------------------------- # --------------------------------------------------------------------------------
@@ -242,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
return fid return fid
return None return None
def _remove_faiss_ids(self, fid_list): async def _remove_faiss_ids(self, fid_list):
""" """
Remove a list of internal Faiss IDs from the index. Remove a list of internal Faiss IDs from the index.
Because IndexFlatIP doesn't support 'removals', Because IndexFlatIP doesn't support 'removals',
@@ -258,6 +284,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
new_id_to_meta[new_fid] = vec_meta new_id_to_meta[new_fid] = vec_meta
async with self._storage_lock:
# Re-init index # Re-init index
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep: if vectors_to_keep:
@@ -312,3 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.warning("Starting with an empty Faiss index.") logger.warning("Starting with an empty Faiss index.")
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {} self._id_to_meta = {}
async def index_done_callback(self) -> None:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
)
async with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
self._save_faiss_index()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
except Exception as e:
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
return False # Return error
return True # Return success

View File

@@ -12,6 +12,11 @@ from lightrag.utils import (
logger,
write_json,
)
from .shared_storage import (
get_namespace_data,
get_storage_lock,
try_initialize_namespace,
)
@final @final
@@ -22,15 +27,30 @@ class JsonDocStatusStorage(DocStatusStorage):
def __post_init__(self): def __post_init__(self):
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data: dict[str, Any] = load_json(self._file_name) or {} self._storage_lock = get_storage_lock()
logger.info(f"Loaded document status storage with {len(self._data)} records") self._data = None
async def initialize(self):
"""Initialize storage data"""
# need_init must be checked before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(
f"Loaded document status storage with {len(loaded_data)} records"
)
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)""" """Return keys that should be processed (not in storage or not successfully processed)"""
async with self._storage_lock:
return set(keys) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = [] result: list[dict[str, Any]] = []
async with self._storage_lock:
for id in ids: for id in ids:
data = self._data.get(id, None) data = self._data.get(id, None)
if data: if data:
@@ -40,6 +60,7 @@ class JsonDocStatusStorage(DocStatusStorage):
async def get_status_counts(self) -> dict[str, int]: async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status""" """Get counts of documents in each status"""
counts = {status.value: 0 for status in DocStatus} counts = {status.value: 0 for status in DocStatus}
async with self._storage_lock:
for doc in self._data.values(): for doc in self._data.values():
counts[doc["status"]] += 1 counts[doc["status"]] += 1
return counts return counts
@@ -49,6 +70,7 @@ class JsonDocStatusStorage(DocStatusStorage):
) -> dict[str, DocProcessingStatus]: ) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status""" """Get all documents with a specific status"""
result = {} result = {}
async with self._storage_lock:
for k, v in self._data.items(): for k, v in self._data.items():
if v["status"] == status.value: if v["status"] == status.value:
try: try:
@@ -64,24 +86,32 @@ class JsonDocStatusStorage(DocStatusStorage):
return result return result
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
write_json(self._data, self._file_name) async with self._storage_lock:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
write_json(data_dict, self._file_name)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
async with self._storage_lock:
self._data.update(data) self._data.update(data)
await self.index_done_callback() await self.index_done_callback()
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
async with self._storage_lock:
return self._data.get(id) return self._data.get(id)
async def delete(self, doc_ids: list[str]): async def delete(self, doc_ids: list[str]):
async with self._storage_lock:
for doc_id in doc_ids: for doc_id in doc_ids:
self._data.pop(doc_id, None) self._data.pop(doc_id, None)
await self.index_done_callback() await self.index_done_callback()
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the storage""" """Drop the storage"""
async with self._storage_lock:
self._data.clear() self._data.clear()

View File

@@ -1,4 +1,3 @@
import asyncio
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, final from typing import Any, final
@@ -11,6 +10,11 @@ from lightrag.utils import (
logger, logger,
write_json, write_json,
) )
from .shared_storage import (
get_namespace_data,
get_storage_lock,
try_initialize_namespace,
)
@final @final
@@ -19,17 +23,33 @@ class JsonKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data: dict[str, Any] = load_json(self._file_name) or {} self._storage_lock = get_storage_lock()
self._lock = asyncio.Lock() self._data = None
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def initialize(self):
"""Initialize storage data"""
# The need_init check must happen before get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = await get_namespace_data(self.namespace)
if need_init:
loaded_data = load_json(self._file_name) or {}
async with self._storage_lock:
self._data.update(loaded_data)
logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
write_json(self._data, self._file_name) async with self._storage_lock:
data_dict = (
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
)
write_json(data_dict, self._file_name)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock:
return self._data.get(id) return self._data.get(id)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
async with self._storage_lock:
return [ return [
( (
{k: v for k, v in self._data[id].items()} {k: v for k, v in self._data[id].items()}
@@ -40,16 +60,19 @@ class JsonKVStorage(BaseKVStorage):
] ]
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
async with self._storage_lock:
return set(keys) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
return return
async with self._storage_lock:
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data) self._data.update(left_data)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
async with self._storage_lock:
for doc_id in ids: for doc_id in ids:
self._data.pop(doc_id, None) self._data.pop(doc_id, None)
await self.index_done_callback() await self.index_done_callback()
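
Both JSON-backed storages in this diff follow the same two-phase initialization: `try_initialize_namespace` grants exactly one process permission to load the file from disk, while `get_namespace_data` hands every worker the same shared dictionary, which is then only mutated under the storage lock. The following is a minimal sketch of that pattern for a hypothetical custom KV storage; the class, file name, and the local `_load_json` helper are illustrative and not part of the diff.

```python
import asyncio
import json
import os

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_storage_lock,
    initialize_share_data,
    try_initialize_namespace,
)


def _load_json(path: str) -> dict:
    # Local stand-in for the JSON loader the real storages use.
    if not os.path.exists(path):
        return {}
    with open(path, encoding="utf-8") as f:
        return json.load(f)


class DemoKVStorage:
    def __init__(self, namespace: str, file_name: str):
        self.namespace = namespace
        self._file_name = file_name
        self._storage_lock = get_storage_lock()
        self._data = None

    async def initialize(self):
        # The permission check must run before get_namespace_data, otherwise the
        # namespace gets created and no process ever loads the file from disk.
        need_init = try_initialize_namespace(self.namespace)
        self._data = await get_namespace_data(self.namespace)
        if need_init:
            async with self._storage_lock:
                self._data.update(_load_json(self._file_name))


async def main():
    store = DemoKVStorage("demo_kv", "kv_store_demo_kv.json")
    await store.initialize()
    async with store._storage_lock:
        store._data["hello"] = {"content": "world"}


if __name__ == "__main__":
    initialize_share_data(workers=1)  # must run before any storage touches shared state
    asyncio.run(main())
```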

View File

@@ -3,7 +3,6 @@ import os
from typing import Any, final from typing import Any, final
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
import time import time
from lightrag.utils import ( from lightrag.utils import (
@@ -11,22 +10,29 @@ from lightrag.utils import (
compute_mdhash_id, compute_mdhash_id,
) )
import pipmaster as pm import pipmaster as pm
from lightrag.base import ( from lightrag.base import BaseVectorStorage
BaseVectorStorage,
)
if not pm.is_installed("nano-vectordb"): if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb") pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB from nano_vectordb import NanoVectorDB
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final @final
@dataclass @dataclass
class NanoVectorDBStorage(BaseVectorStorage): class NanoVectorDBStorage(BaseVectorStorage):
def __post_init__(self): def __post_init__(self):
# Initialize lock only for file operations # Initialize basic attributes
self._save_lock = asyncio.Lock() self._client = None
self._storage_lock = None
self.storage_updated = None
# Use global config value if specified, otherwise use default # Use global config value if specified, otherwise use default
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold") cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -40,10 +46,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB( self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
) )
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_client(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
):
logger.info(
f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
)
# Reload data
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._client
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}") logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data: if not data:
@@ -64,6 +103,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
for i in range(0, len(contents), self._max_batch_size) for i in range(0, len(contents), self._max_batch_size)
] ]
# Execute embedding outside of lock to avoid long lock times
embedding_tasks = [self.embedding_func(batch) for batch in batches] embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = await asyncio.gather(*embedding_tasks) embeddings_list = await asyncio.gather(*embedding_tasks)
@@ -71,7 +111,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
if len(embeddings) == len(list_data): if len(embeddings) == len(list_data):
for i, d in enumerate(list_data): for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i] d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data) client = await self._get_client()
results = client.upsert(datas=list_data)
return results return results
else: else:
# sometimes the embedding is not returned correctly. just log it. # sometimes the embedding is not returned correctly. just log it.
@@ -80,9 +121,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid long lock times
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
embedding = embedding[0] embedding = embedding[0]
results = self._client.query(
client = await self._get_client()
results = client.query(
query=embedding, query=embedding,
top_k=top_k, top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold, better_than_threshold=self.cosine_better_than_threshold,
@@ -99,8 +143,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
return results return results
@property @property
def client_storage(self): async def client_storage(self):
return getattr(self._client, "_NanoVectorDB__storage") client = await self._get_client()
return getattr(client, "_NanoVectorDB__storage")
async def delete(self, ids: list[str]): async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs """Delete vectors with specified IDs
@@ -109,8 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
ids: List of vector IDs to be deleted ids: List of vector IDs to be deleted
""" """
try: try:
self._client.delete(ids) client = await self._get_client()
logger.info( client.delete(ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}" f"Successfully deleted {len(ids)} vectors from {self.namespace}"
) )
except Exception as e: except Exception as e:
@@ -122,9 +168,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
logger.debug( logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}" f"Attempting to delete entity {entity_name} with ID {entity_id}"
) )
# Check if the entity exists # Check if the entity exists
if self._client.get([entity_id]): client = await self._get_client()
await self.delete([entity_id]) if client.get([entity_id]):
client.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}") logger.debug(f"Successfully deleted entity {entity_name}")
else: else:
logger.debug(f"Entity {entity_name} not found in storage") logger.debug(f"Entity {entity_name} not found in storage")
@@ -133,16 +181,19 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
try: try:
client = await self._get_client()
storage = getattr(client, "_NanoVectorDB__storage")
relations = [ relations = [
dp dp
for dp in self.client_storage["data"] for dp in storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
] ]
logger.debug(f"Found {len(relations)} relations for entity {entity_name}") logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
ids_to_delete = [relation["__id__"] for relation in relations] ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete: if ids_to_delete:
await self.delete(ids_to_delete) client = await self._get_client()
client.delete(ids_to_delete)
logger.debug( logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}" f"Deleted {len(ids_to_delete)} relations for {entity_name}"
) )
@@ -151,6 +202,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self) -> None: async def index_done_callback(self) -> bool:
async with self._save_lock: """Save data to disk"""
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for {self.namespace} was updated by another process, reloading..."
)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
self._client.save() self._client.save()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return True # Return success
except Exception as e:
logger.error(f"Error saving data for {self.namespace}: {e}")
return False # Return error
return True # Return success
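
The pattern above — one update flag per worker, raised by whichever process persists and checked before every read — is what keeps multiple Gunicorn workers consistent without sharing the NanoVectorDB object itself. Below is a condensed sketch of the same check-reload-reset cycle with the vector client replaced by a plain dict so the flow stands on its own; everything except the shared_storage helpers is illustrative.

```python
import asyncio

from lightrag.kg import shared_storage
from lightrag.kg.shared_storage import (
    get_storage_lock,
    get_update_flag,
    initialize_share_data,
    set_all_update_flags,
)


class ReloadableStore:
    """Stand-in for NanoVectorDBStorage; self._data replaces the vector client."""

    def __init__(self, namespace: str):
        self.namespace = namespace
        self._data = {}
        self._storage_lock = None
        self.storage_updated = None

    async def initialize(self):
        self.storage_updated = await get_update_flag(self.namespace)
        self._storage_lock = get_storage_lock()

    async def _get_data(self):
        async with self._storage_lock:
            # Read is_multiprocess via the module so the value set by
            # initialize_share_data is seen here.
            if shared_storage.is_multiprocess:
                updated = self.storage_updated.value
            else:
                updated = self.storage_updated
            if updated:
                self._data = {}  # the real storage re-reads its file here
                if shared_storage.is_multiprocess:
                    self.storage_updated.value = False
                else:
                    self.storage_updated = False
            return self._data

    async def index_done_callback(self):
        async with self._storage_lock:
            # Persist self._data here, then tell every worker to reload on next access.
            await set_all_update_flags(self.namespace)
            # The real storage also resets its own flag afterwards to avoid self-reloading.


async def main():
    store = ReloadableStore("demo_vectors")
    await store.initialize()
    (await store._get_data())["doc-1"] = [0.1, 0.2]
    await store.index_done_callback()


if __name__ == "__main__":
    initialize_share_data(workers=1)
    asyncio.run(main())
```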

View File

@@ -1,18 +1,12 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, final from typing import Any, final
import numpy as np import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import ( from lightrag.utils import logger
logger, from lightrag.base import BaseGraphStorage
)
from lightrag.base import (
BaseGraphStorage,
)
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("networkx"): if not pm.is_installed("networkx"):
@@ -23,6 +17,12 @@ if not pm.is_installed("graspologic"):
import networkx as nx import networkx as nx
from graspologic import embed from graspologic import embed
from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final @final
@@ -78,56 +78,101 @@ class NetworkXStorage(BaseGraphStorage):
self._graphml_xml_file = os.path.join( self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml" self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
) )
self._storage_lock = None
self.storage_updated = None
self._graph = None
# Load initial graph
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None: if preloaded_graph is not None:
logger.info( logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
) )
else:
logger.info("Created new empty graph")
self._graph = preloaded_graph or nx.Graph() self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = { self._node_embed_algorithms = {
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def index_done_callback(self) -> None: async def initialize(self):
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) """Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_graph(self):
"""Check if the storage should be reloaded"""
# Acquire lock to prevent concurrent read and write
async with self._storage_lock:
# Check if data needs to be reloaded
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
):
logger.info(
f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
)
# Reload data
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._graph
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
return self._graph.has_node(node_id) graph = await self._get_graph()
return graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id) graph = await self._get_graph()
return graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
return self._graph.nodes.get(node_id) graph = await self._get_graph()
return graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
return self._graph.degree(node_id) graph = await self._get_graph()
return graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._graph.degree(src_id) + self._graph.degree(tgt_id) graph = await self._get_graph()
return graph.degree(src_id) + graph.degree(tgt_id)
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
return self._graph.edges.get((source_node_id, target_node_id)) graph = await self._get_graph()
return graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
if self._graph.has_node(source_node_id): graph = await self._get_graph()
return list(self._graph.edges(source_node_id)) if graph.has_node(source_node_id):
return list(graph.edges(source_node_id))
return None return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
self._graph.add_node(node_id, **node_data) graph = await self._get_graph()
graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None: ) -> None:
self._graph.add_edge(source_node_id, target_node_id, **edge_data) graph = await self._get_graph()
graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
if self._graph.has_node(node_id): graph = await self._get_graph()
self._graph.remove_node(node_id) if graph.has_node(node_id):
logger.info(f"Node {node_id} deleted from the graph.") graph.remove_node(node_id)
logger.debug(f"Node {node_id} deleted from the graph.")
else: else:
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
@@ -138,35 +183,37 @@ class NetworkXStorage(BaseGraphStorage):
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED # TODO: NOT USED
async def _node2vec_embed(self): async def _node2vec_embed(self):
graph = await self._get_graph()
embeddings, nodes = embed.node2vec_embed( embeddings, nodes = embed.node2vec_embed(
self._graph, graph,
**self.global_config["node2vec_params"], **self.global_config["node2vec_params"],
) )
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids return embeddings, nodes_ids
def remove_nodes(self, nodes: list[str]): async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes """Delete multiple nodes
Args: Args:
nodes: List of node IDs to be deleted nodes: List of node IDs to be deleted
""" """
graph = await self._get_graph()
for node in nodes: for node in nodes:
if self._graph.has_node(node): if graph.has_node(node):
self._graph.remove_node(node) graph.remove_node(node)
def remove_edges(self, edges: list[tuple[str, str]]): async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges """Delete multiple edges
Args: Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple edges: List of edges to be deleted, each edge is a (source, target) tuple
""" """
graph = await self._get_graph()
for source, target in edges: for source, target in edges:
if self._graph.has_edge(source, target): if graph.has_edge(source, target):
self._graph.remove_edge(source, target) graph.remove_edge(source, target)
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
""" """
@@ -174,8 +221,9 @@ class NetworkXStorage(BaseGraphStorage):
Returns: Returns:
[label1, label2, ...] # Alphabetically sorted label list [label1, label2, ...] # Alphabetically sorted label list
""" """
graph = await self._get_graph()
labels = set() labels = set()
for node in self._graph.nodes(): for node in graph.nodes():
labels.add(str(node)) # Add node id as a label labels.add(str(node)) # Add node id as a label
# Return sorted list # Return sorted list
@@ -198,16 +246,18 @@ class NetworkXStorage(BaseGraphStorage):
seen_nodes = set() seen_nodes = set()
seen_edges = set() seen_edges = set()
graph = await self._get_graph()
# Handle special case for "*" label # Handle special case for "*" label
if node_label == "*": if node_label == "*":
# For "*", return the entire graph including all nodes and edges # For "*", return the entire graph including all nodes and edges
subgraph = ( subgraph = (
self._graph.copy() graph.copy()
) # Create a copy to avoid modifying the original graph ) # Create a copy to avoid modifying the original graph
else: else:
# Find nodes with matching node id (partial match) # Find nodes with matching node id (partial match)
nodes_to_explore = [] nodes_to_explore = []
for n, attr in self._graph.nodes(data=True): for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching if node_label in str(n): # Use partial matching
nodes_to_explore.append(n) nodes_to_explore.append(n)
@@ -216,7 +266,7 @@ class NetworkXStorage(BaseGraphStorage):
return result return result
# Get subgraph using ego_graph # Get subgraph using ego_graph
subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth) subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Check if number of nodes exceeds max_graph_nodes # Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500 max_graph_nodes = 500
@@ -278,9 +328,41 @@ class NetworkXStorage(BaseGraphStorage):
) )
seen_edges.add(edge_id) seen_edges.add(edge_id)
# logger.info(result.edges)
logger.info( logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
) )
return result return result
async def index_done_callback(self) -> bool:
"""Save data to disk"""
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Graph for {self.namespace} was updated by another process, reloading..."
)
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return True # Return success
except Exception as e:
logger.error(f"Error saving graph for {self.namespace}: {e}")
return False # Return error
return True
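
One behavioural change worth flagging: remove_nodes and remove_edges were synchronous before this diff and are coroutines now, so callers have to await them (a bare call would only create an un-awaited coroutine and change nothing). A small sketch, assuming an already-initialized NetworkXStorage instance named graph_storage:

```python
async def prune_graph(graph_storage) -> None:
    # Both methods are now coroutines that go through _get_graph() internally.
    await graph_storage.remove_nodes(["obsolete entity"])
    await graph_storage.remove_edges([("entity a", "entity b")])
    # Persist to GraphML and raise the update flag for the other workers.
    await graph_storage.index_done_callback()
```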

View File

@@ -38,8 +38,8 @@ import pipmaster as pm
if not pm.is_installed("asyncpg"): if not pm.is_installed("asyncpg"):
pm.install("asyncpg") pm.install("asyncpg")
import asyncpg import asyncpg # type: ignore
from asyncpg import Pool from asyncpg import Pool # type: ignore
class PostgreSQLDB: class PostgreSQLDB:

View File

@@ -0,0 +1,374 @@
import os
import sys
import asyncio
from multiprocessing.synchronize import Lock as ProcessLock
from multiprocessing import Manager
from typing import Any, Dict, Optional, Union, TypeVar, Generic
# Define a direct print function for critical logs that must be visible in all processes
def direct_log(message, level="INFO"):
"""
Log a message directly to stderr to ensure visibility in all processes,
including the Gunicorn master process.
"""
print(f"{level}: {message}", file=sys.stderr, flush=True)
T = TypeVar("T")
LockType = Union[ProcessLock, asyncio.Lock]
is_multiprocess = None
_workers = None
_manager = None
_initialized = None
# shared data for storage across processes
_shared_dicts: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
# locks for mutex access
_storage_lock: Optional[LockType] = None
_internal_lock: Optional[LockType] = None
_pipeline_status_lock: Optional[LockType] = None
class UnifiedLock(Generic[T]):
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
self._lock = lock
self._is_async = is_async
async def __aenter__(self) -> "UnifiedLock[T]":
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._is_async:
self._lock.release()
else:
self._lock.release()
def __enter__(self) -> "UnifiedLock[T]":
"""For backward compatibility"""
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
self._lock.acquire()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""For backward compatibility"""
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
self._lock.release()
def get_internal_lock() -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
def get_storage_lock() -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
def get_pipeline_status_lock() -> UnifiedLock:
"""return unified storage lock for data consistency"""
return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
def initialize_share_data(workers: int = 1):
"""
Initialize shared storage data for single or multi-process mode.
When used with Gunicorn's preload feature, this function is called once in the
master process before forking worker processes, allowing all workers to share
the same initialized data.
In single-process mode, this function is called in the FastAPI lifespan function.
The function determines whether to use cross-process shared variables for data storage
based on the number of workers. If workers=1, it uses asyncio locks and local dictionaries.
If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager.
Args:
workers (int): Number of worker processes. If 1, single-process mode is used.
If > 1, multi-process mode with shared memory is used.
"""
global \
_manager, \
_workers, \
is_multiprocess, \
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
_shared_dicts, \
_init_flags, \
_initialized, \
_update_flags
# Check if already initialized
if _initialized:
direct_log(
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
)
return
_manager = Manager()
_workers = workers
if workers > 1:
is_multiprocess = True
_internal_lock = _manager.Lock()
_storage_lock = _manager.Lock()
_pipeline_status_lock = _manager.Lock()
_shared_dicts = _manager.dict()
_init_flags = _manager.dict()
_update_flags = _manager.dict()
direct_log(
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
)
else:
is_multiprocess = False
_internal_lock = asyncio.Lock()
_storage_lock = asyncio.Lock()
_pipeline_status_lock = asyncio.Lock()
_shared_dicts = {}
_init_flags = {}
_update_flags = {}
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
# Mark as initialized
_initialized = True
async def initialize_pipeline_status():
"""
Initialize pipeline namespace with default values.
This function is called during the FastAPI lifespan for each worker.
"""
pipeline_namespace = await get_namespace_data("pipeline_status")
async with get_internal_lock():
# Check if already initialized by checking for required fields
if "busy" in pipeline_namespace:
return
# Create a shared list object for history_messages
history_messages = _manager.list() if is_multiprocess else []
pipeline_namespace.update(
{
"busy": False, # Control concurrent processes
"job_name": "Default Job", # Current job name (indexing files/indexing texts)
"job_start": None, # Job start time
"docs": 0, # Total number of documents to be indexed
"batchs": 0, # Number of batches for processing documents
"cur_batch": 0, # Current processing batch
"request_pending": False, # Flag for pending request for processing
"latest_message": "", # Latest message from pipeline processing
"history_messages": history_messages, # 使用共享列表对象
}
)
direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
async def get_update_flag(namespace: str):
"""
Create a namespace's update flag for a worker.
Return the update flag to the caller for referencing or resetting.
"""
global _update_flags
if _update_flags is None:
raise ValueError("Try to create namespace before Shared-Data is initialized")
async with get_internal_lock():
if namespace not in _update_flags:
if is_multiprocess and _manager is not None:
_update_flags[namespace] = _manager.list()
else:
_update_flags[namespace] = []
direct_log(
f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
)
if is_multiprocess and _manager is not None:
new_update_flag = _manager.Value("b", False)
else:
new_update_flag = False
_update_flags[namespace].append(new_update_flag)
return new_update_flag
async def set_all_update_flags(namespace: str):
"""Set all update flag of namespace indicating all workers need to reload data from files"""
global _update_flags
if _update_flags is None:
raise ValueError("Try to create namespace before Shared-Data is initialized")
async with get_internal_lock():
if namespace not in _update_flags:
raise ValueError(f"Namespace {namespace} not found in update flags")
# Update flags for both modes
for i in range(len(_update_flags[namespace])):
if is_multiprocess:
_update_flags[namespace][i].value = True
else:
_update_flags[namespace][i] = True
async def get_all_update_flags_status() -> Dict[str, list]:
"""
Get update flags status for all namespaces.
Returns:
Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
"""
if _update_flags is None:
return {}
result = {}
async with get_internal_lock():
for namespace, flags in _update_flags.items():
worker_statuses = []
for flag in flags:
if is_multiprocess:
worker_statuses.append(flag.value)
else:
worker_statuses.append(flag)
result[namespace] = worker_statuses
return result
def try_initialize_namespace(namespace: str) -> bool:
"""
Returns True if the current worker (process) gets initialization permission for loading data later.
A worker that does not get the permission must not load data from files.
"""
global _init_flags, _manager
if _init_flags is None:
raise ValueError("Try to create nanmespace before Shared-Data is initialized")
if namespace not in _init_flags:
_init_flags[namespace] = True
direct_log(
f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
)
return True
direct_log(
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
)
return False
async def get_namespace_data(namespace: str) -> Dict[str, Any]:
"""get the shared data reference for specific namespace"""
if _shared_dicts is None:
direct_log(
f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
level="ERROR",
)
raise ValueError("Shared dictionaries not initialized")
async with get_internal_lock():
if namespace not in _shared_dicts:
if is_multiprocess and _manager is not None:
_shared_dicts[namespace] = _manager.dict()
else:
_shared_dicts[namespace] = {}
return _shared_dicts[namespace]
def finalize_share_data():
"""
Release shared resources and clean up.
This function should be called when the application is shutting down
to properly release shared resources and avoid memory leaks.
In multi-process mode, it shuts down the Manager and releases all shared objects.
In single-process mode, it simply resets the global variables.
"""
global \
_manager, \
is_multiprocess, \
_storage_lock, \
_internal_lock, \
_pipeline_status_lock, \
_shared_dicts, \
_init_flags, \
_initialized, \
_update_flags
# Check if already initialized
if not _initialized:
direct_log(
f"Process {os.getpid()} storage data not initialized, nothing to finalize"
)
return
direct_log(
f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
)
# In multi-process mode, shut down the Manager
if is_multiprocess and _manager is not None:
try:
# Clear shared resources before shutting down Manager
if _shared_dicts is not None:
# Clear pipeline status history messages first if exists
try:
pipeline_status = _shared_dicts.get("pipeline_status", {})
if "history_messages" in pipeline_status:
pipeline_status["history_messages"].clear()
except Exception:
pass # Ignore any errors during history messages cleanup
_shared_dicts.clear()
if _init_flags is not None:
_init_flags.clear()
if _update_flags is not None:
# Clear each namespace's update flags list and Value objects
try:
for namespace in _update_flags:
flags_list = _update_flags[namespace]
if isinstance(flags_list, list):
# Clear Value objects in the list
for flag in flags_list:
if hasattr(
flag, "value"
): # Check if it's a Value object
flag.value = False
flags_list.clear()
except Exception:
pass # Ignore any errors during update flags cleanup
_update_flags.clear()
# Shut down the Manager - this will automatically clean up all shared resources
_manager.shutdown()
direct_log(f"Process {os.getpid()} Manager shutdown complete")
except Exception as e:
direct_log(
f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR"
)
# Reset global variables
_manager = None
_initialized = None
is_multiprocess = None
_shared_dicts = None
_init_flags = None
_storage_lock = None
_internal_lock = None
_pipeline_status_lock = None
_update_flags = None
direct_log(f"Process {os.getpid()} storage data finalization complete")

View File

@@ -45,7 +45,6 @@ from .utils import (
lazy_external_import, lazy_external_import,
limit_async_func_call, limit_async_func_call,
logger, logger,
set_logger,
) )
from .types import KnowledgeGraph from .types import KnowledgeGraph
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -268,9 +267,14 @@ class LightRAG:
def __post_init__(self): def __post_init__(self):
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
set_logger(self.log_file_path, self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}") logger.info(f"Logger initialized for working directory: {self.working_dir}")
from lightrag.kg.shared_storage import (
initialize_share_data,
)
initialize_share_data()
if not os.path.exists(self.working_dir): if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir) os.makedirs(self.working_dir)
@@ -692,8 +696,20 @@ class LightRAG:
3. Process each chunk for entity and relation extraction 3. Process each chunk for entity and relation extraction
4. Update the document status 4. Update the document status
""" """
# 1. Get all pending, failed, and abnormally terminated processing documents. from lightrag.kg.shared_storage import (
# Run the asynchronous status retrievals in parallel using asyncio.gather get_namespace_data,
get_pipeline_status_lock,
)
# Get pipeline status shared data and lock
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status_lock = get_pipeline_status_lock()
# Check if another process is already processing the queue
async with pipeline_status_lock:
# Ensure only one worker is processing documents
if not pipeline_status.get("busy", False):
# First check whether there are documents that need processing
processing_docs, failed_docs, pending_docs = await asyncio.gather( processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING), self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED), self.doc_status.get_docs_by_status(DocStatus.FAILED),
@@ -705,28 +721,76 @@ class LightRAG:
to_process_docs.update(failed_docs) to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs) to_process_docs.update(pending_docs)
# If there are no documents to process, return immediately and leave pipeline_status unchanged
if not to_process_docs: if not to_process_docs:
logger.info("All documents have been processed or are duplicates") logger.info("No documents to process")
return return
# There are documents to process; update pipeline_status
pipeline_status.update(
{
"busy": True,
"job_name": "indexing files",
"job_start": datetime.now().isoformat(),
"docs": 0,
"batchs": 0,
"cur_batch": 0,
"request_pending": False, # Clear any previous request
"latest_message": "",
}
)
# Clear history_messages in place so it stays the same shared list object
del pipeline_status["history_messages"][:]
else:
# Another process is busy, just set request flag and return
pipeline_status["request_pending"] = True
logger.info(
"Another process is already processing the document queue. Request queued."
)
return
try:
# Process documents until no more documents or requests
while True:
if not to_process_docs:
log_message = "All documents have been processed or are duplicates"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
break
# 2. split docs into chunks, insert chunks, update doc status # 2. split docs into chunks, insert chunks, update doc status
docs_batches = [ docs_batches = [
list(to_process_docs.items())[i : i + self.max_parallel_insert] list(to_process_docs.items())[i : i + self.max_parallel_insert]
for i in range(0, len(to_process_docs), self.max_parallel_insert) for i in range(0, len(to_process_docs), self.max_parallel_insert)
] ]
logger.info(f"Number of batches to process: {len(docs_batches)}.") log_message = f"Number of batches to process: {len(docs_batches)}."
logger.info(log_message)
# Update pipeline status with current batch information
pipeline_status["docs"] += len(to_process_docs)
pipeline_status["batchs"] += len(docs_batches)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
batches: list[Any] = [] batches: list[Any] = []
# 3. iterate over batches # 3. iterate over batches
for batch_idx, docs_batch in enumerate(docs_batches): for batch_idx, docs_batch in enumerate(docs_batches):
# Update current batch in pipeline status (directly, as it's atomic)
pipeline_status["cur_batch"] += 1
async def batch( async def batch(
batch_idx: int, batch_idx: int,
docs_batch: list[tuple[str, DocProcessingStatus]], docs_batch: list[tuple[str, DocProcessingStatus]],
size_batch: int, size_batch: int,
) -> None: ) -> None:
logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.") log_message = (
f"Start processing batch {batch_idx + 1} of {size_batch}."
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# 4. iterate over batch # 4. iterate over batch
for doc_id_processing_status in docs_batch: for doc_id_processing_status in docs_batch:
doc_id, status_doc = doc_id_processing_status doc_id, status_doc = doc_id_processing_status
@@ -782,7 +846,9 @@ class LightRAG:
} }
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to process document {doc_id}: {str(e)}") logger.error(
f"Failed to process document {doc_id}: {str(e)}"
)
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
doc_id: { doc_id: {
@@ -797,13 +863,55 @@ class LightRAG:
} }
) )
continue continue
logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.") log_message = (
f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
)
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
batches.append(batch(batch_idx, docs_batch, len(docs_batches))) batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
await asyncio.gather(*batches) await asyncio.gather(*batches)
await self._insert_done() await self._insert_done()
# Check if there's a pending request to process more documents (with lock)
has_pending_request = False
async with pipeline_status_lock:
has_pending_request = pipeline_status.get("request_pending", False)
if has_pending_request:
# Clear the request flag before checking for more documents
pipeline_status["request_pending"] = False
if not has_pending_request:
break
log_message = "Processing additional documents due to pending request"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Fetch any newly pending documents
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
to_process_docs = {}
to_process_docs.update(processing_docs)
to_process_docs.update(failed_docs)
to_process_docs.update(pending_docs)
finally:
log_message = "Document processing pipeline completed"
logger.info(log_message)
# Always reset busy status when done or if an exception occurs (with lock)
async with pipeline_status_lock:
pipeline_status["busy"] = False
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
try: try:
await extract_entities( await extract_entities(
@@ -833,7 +941,16 @@ class LightRAG:
if storage_inst is not None if storage_inst is not None
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
logger.info("All Insert done")
log_message = "All Insert done"
logger.info(log_message)
# Get pipeline_status and update latest_message and history_messages
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None: def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
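
The busy/request_pending handshake introduced above is what serializes document indexing across workers: the first caller flips busy under the pipeline lock and keeps looping until both the queue and the pending flag are empty, while any concurrent caller merely records a request and returns. A stripped-down sketch of that handshake, with the actual document processing replaced by a placeholder sleep; everything except the shared_storage helpers is illustrative.

```python
import asyncio

from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_pipeline_status_lock,
    initialize_pipeline_status,
    initialize_share_data,
)


async def process_queue(name: str):
    pipeline_status = await get_namespace_data("pipeline_status")
    lock = get_pipeline_status_lock()

    async with lock:
        if pipeline_status.get("busy", False):
            # Another caller is indexing: leave a note and return immediately.
            pipeline_status["request_pending"] = True
            print(f"{name}: queue busy, request recorded")
            return
        pipeline_status["busy"] = True

    try:
        while True:
            await asyncio.sleep(0.05)  # placeholder for processing one round of documents
            async with lock:
                if not pipeline_status.get("request_pending", False):
                    break
                pipeline_status["request_pending"] = False  # drain the queued request
    finally:
        async with lock:
            pipeline_status["busy"] = False
            print(f"{name}: pipeline released")


async def main():
    await initialize_pipeline_status()
    await asyncio.gather(process_queue("caller-1"), process_queue("caller-2"))


if __name__ == "__main__":
    initialize_share_data(workers=1)
    asyncio.run(main())
```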

View File

@@ -339,6 +339,9 @@ async def extract_entities(
global_config: dict[str, str], global_config: dict[str, str],
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
) -> None: ) -> None:
from lightrag.kg.shared_storage import get_namespace_data
pipeline_status = await get_namespace_data("pipeline_status")
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
enable_llm_cache_for_entity_extract: bool = global_config[ enable_llm_cache_for_entity_extract: bool = global_config[
@@ -499,9 +502,10 @@ async def extract_entities(
processed_chunks += 1 processed_chunks += 1
entities_count = len(maybe_nodes) entities_count = len(maybe_nodes)
relations_count = len(maybe_edges) relations_count = len(maybe_edges)
logger.info( log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)" logger.info(log_message)
) pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return dict(maybe_nodes), dict(maybe_edges) return dict(maybe_nodes), dict(maybe_edges)
tasks = [_process_single_content(c) for c in ordered_chunks] tasks = [_process_single_content(c) for c in ordered_chunks]
@@ -530,17 +534,27 @@ async def extract_entities(
) )
if not (all_entities_data or all_relationships_data): if not (all_entities_data or all_relationships_data):
logger.info("Didn't extract any entities and relationships.") log_message = "Didn't extract any entities and relationships."
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
return return
if not all_entities_data: if not all_entities_data:
logger.info("Didn't extract any entities") log_message = "Didn't extract any entities"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
if not all_relationships_data: if not all_relationships_data:
logger.info("Didn't extract any relationships") log_message = "Didn't extract any relationships"
logger.info(log_message)
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
logger.info( log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)" logger.info(log_message)
) pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
verbose_debug( verbose_debug(
f"New entities:{all_entities_data}, relationships:{all_relationships_data}" f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
) )

View File

@@ -56,6 +56,18 @@ def set_verbose_debug(enabled: bool):
VERBOSE_DEBUG = enabled VERBOSE_DEBUG = enabled
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
# Initialize logger
logger = logging.getLogger("lightrag")
logger.propagate = False  # prevent log messages from being sent to the root logger
# Let the main application configure the handlers
logger.setLevel(logging.INFO)
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
class UnlimitedSemaphore: class UnlimitedSemaphore:
"""A context manager that allows unlimited access.""" """A context manager that allows unlimited access."""
@@ -68,34 +80,6 @@ class UnlimitedSemaphore:
ENCODER = None ENCODER = None
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
logger = logging.getLogger("lightrag")
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str, level: int = logging.DEBUG):
"""Set up file logging with the specified level.
Args:
log_file: Path to the log file
level: Logging level (e.g. logging.DEBUG, logging.INFO)
"""
logger.setLevel(level)
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setLevel(level)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
if not logger.handlers:
logger.addHandler(file_handler)
@dataclass @dataclass
class EmbeddingFunc: class EmbeddingFunc:
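
With set_logger removed, the library now only creates the "lightrag" logger (propagation off, level INFO) and leaves handler wiring to the embedding application. Below is a sketch of one way an application might attach console plus size-rotated file output before anything starts logging; the handler choice, sizes, and file name are illustrative, not dictated by the diff.

```python
import logging
from logging.handlers import RotatingFileHandler


def configure_lightrag_logging(log_file: str = "lightrag.log") -> None:
    logger = logging.getLogger("lightrag")
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    console = logging.StreamHandler()
    console.setFormatter(formatter)

    # Rotate at roughly 10 MB and keep 5 backups; adjust to taste.
    file_handler = RotatingFileHandler(
        log_file, maxBytes=10 * 1024 * 1024, backupCount=5, encoding="utf-8"
    )
    file_handler.setFormatter(formatter)

    if not logger.handlers:
        logger.addHandler(console)
        logger.addHandler(file_handler)


configure_lightrag_logging()
```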

203
run_with_gunicorn.py Executable file
View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python
"""
Start LightRAG server with Gunicorn
"""
import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
def check_and_install_dependencies():
"""Check and install required dependencies"""
required_packages = [
"gunicorn",
"tiktoken",
"psutil",
# Add other required packages here
]
for package in required_packages:
if not pm.is_installed(package):
print(f"Installing {package}...")
pm.install(package)
print(f"{package} installed successfully")
# Signal handler for graceful shutdown
def signal_handler(sig, frame):
print("\n\n" + "=" * 80)
print("RECEIVED TERMINATION SIGNAL")
print(f"Process ID: {os.getpid()}")
print("=" * 80 + "\n")
# Release shared resources
finalize_share_data()
# Exit with success status
sys.exit(0)
def main():
# Check and install dependencies
check_and_install_dependencies()
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # kill command
# Parse all arguments using parse_args
args = parse_args(is_uvicorn_mode=False)
# Display startup information
display_splash_screen(args)
print("🚀 Starting LightRAG with Gunicorn")
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
print("🔍 Preloading app: Enabled")
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
print("\n\n" + "=" * 80)
print("MAIN PROCESS INITIALIZATION")
print(f"Process ID: {os.getpid()}")
print(f"Workers setting: {args.workers}")
print("=" * 80 + "\n")
# Import Gunicorn's StandaloneApplication
from gunicorn.app.base import BaseApplication
# Define a custom application class that loads our config
class GunicornApp(BaseApplication):
def __init__(self, app, options=None):
self.options = options or {}
self.application = app
super().__init__()
def load_config(self):
# Define valid Gunicorn configuration options
valid_options = {
"bind",
"workers",
"worker_class",
"timeout",
"keepalive",
"preload_app",
"errorlog",
"accesslog",
"loglevel",
"certfile",
"keyfile",
"limit_request_line",
"limit_request_fields",
"limit_request_field_size",
"graceful_timeout",
"max_requests",
"max_requests_jitter",
}
# Special hooks that need to be set separately
special_hooks = {
"on_starting",
"on_reload",
"on_exit",
"pre_fork",
"post_fork",
"pre_exec",
"pre_request",
"post_request",
"worker_init",
"worker_exit",
"nworkers_changed",
"child_exit",
}
# Import and configure the gunicorn_config module
import gunicorn_config
# Set configuration variables in gunicorn_config, prioritizing command line arguments
gunicorn_config.workers = (
args.workers if args.workers else int(os.getenv("WORKERS", 1))
)
# Bind configuration prioritizes command line arguments
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
gunicorn_config.bind = f"{host}:{port}"
# Log level configuration prioritizes command line arguments
gunicorn_config.loglevel = (
args.log_level.lower()
if args.log_level
else os.getenv("LOG_LEVEL", "info")
)
# Timeout configuration prioritizes command line arguments
gunicorn_config.timeout = (
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
)
# Keepalive configuration
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
# SSL configuration prioritizes command line arguments
if args.ssl or os.getenv("SSL", "").lower() in (
"true",
"1",
"yes",
"t",
"on",
):
gunicorn_config.certfile = (
args.ssl_certfile
if args.ssl_certfile
else os.getenv("SSL_CERTFILE")
)
gunicorn_config.keyfile = (
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
)
# Set configuration options from the module
for key in dir(gunicorn_config):
if key in valid_options:
value = getattr(gunicorn_config, key)
# Skip functions like on_starting and None values
if not callable(value) and value is not None:
self.cfg.set(key, value)
# Set special hooks
elif key in special_hooks:
value = getattr(gunicorn_config, key)
if callable(value):
self.cfg.set(key, value)
if hasattr(gunicorn_config, "logconfig_dict"):
self.cfg.set(
"logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
)
def load(self):
# Import the application
from lightrag.api.lightrag_server import get_application
return get_application(args)
# Create the application
app = GunicornApp("")
# Force workers to be an integer and greater than 1 for multi-process mode
workers_count = int(args.workers)
if workers_count > 1:
# Set a flag to indicate we're in the main process
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
initialize_share_data(workers_count)
else:
initialize_share_data(1)
# Run the application
print("\nStarting Gunicorn with direct Python API...")
app.run()
if __name__ == "__main__":
main()
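
run_with_gunicorn.py reads its defaults from a gunicorn_config module that is not shown in this diff. The sketch below shows what such a module could contain, limited to the attribute names the loader above actually inspects (valid_options and special_hooks); all values are placeholders, not the shipped configuration.

```python
# gunicorn_config.py -- illustrative only; the real module ships with LightRAG.
import os

workers = int(os.getenv("WORKERS", 2))
bind = f"{os.getenv('HOST', '0.0.0.0')}:{os.getenv('PORT', 9621)}"
loglevel = os.getenv("LOG_LEVEL", "info")
timeout = int(os.getenv("TIMEOUT", 150))
keepalive = int(os.getenv("KEEPALIVE", 5))
preload_app = True  # required so shared data is created once in the master


def on_starting(server):
    # Runs in the Gunicorn master before workers are forked.
    print(f"Gunicorn master {os.getpid()} starting with {workers} workers")


def on_exit(server):
    print(f"Gunicorn master {os.getpid()} shutting down")
```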

View File

@@ -112,6 +112,7 @@ setuptools.setup(
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"lightrag-server=lightrag.api.lightrag_server:main [api]", "lightrag-server=lightrag.api.lightrag_server:main [api]",
"lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]",
"lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]", "lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
], ],
}, },