Merge pull request #969 from danielaskdd/add-multi-worker-support
Add multi-worker support for the API Server
.gitignore (vendored, 1 line changed)
@@ -21,6 +21,7 @@ site/

# Logs / Reports
*.log
*.log.*
*.logfire
*.coverage/
log/

@@ -1,6 +1,9 @@
### This is sample file of .env

### Server Configuration
# HOST=0.0.0.0
# PORT=9621
# WORKERS=1
# NAMESPACE_PREFIX=lightrag # separating data from different LightRAG instances
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080

@@ -22,6 +25,9 @@
### Logging level
# LOG_LEVEL=INFO
# VERBOSE=False
# LOG_DIR=/path/to/log/directory # Log file directory path, defaults to current working directory
# LOG_MAX_BYTES=10485760 # Log file max size in bytes, defaults to 10MB
# LOG_BACKUP_COUNT=5 # Number of backup files to keep, defaults to 5

### Max async calls for LLM
# MAX_ASYNC=4

@@ -138,3 +144,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility)
### Qdrant
QDRANT_URL=http://localhost:16333
# QDRANT_API_KEY=your-api-key

### Redis
REDIS_URI=redis://localhost:6379
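
For reference, the LOG_* settings above can also be exported in the shell to redirect and bound the server log before launching (paths and sizes illustrative):

```bash
export LOG_DIR=/var/log/lightrag     # write lightrag.log here instead of the working directory
export LOG_MAX_BYTES=52428800        # rotate after 50 MB
export LOG_BACKUP_COUNT=10           # keep 10 rotated files
lightrag-server
```
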
@@ -24,6 +24,8 @@ pip install -e ".[api]"

### Starting API Server with Default Settings

After installing LightRAG with API support, you can start LightRAG with the following command: `lightrag-server`
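
For example (illustrative values; the server also reads HOST and PORT from the environment, as shown in the sample .env above):

```bash
# Start with the defaults (0.0.0.0:9621, single Uvicorn process)
lightrag-server

# Override the listen address and port on the command line
lightrag-server --host 127.0.0.1 --port 9621
```
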
LightRAG requires both an LLM and an embedding model to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends:

* ollama

@@ -92,10 +94,43 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama
LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai

# start with ollama llm and ollama embedding (no apikey is needed)
Light_server --llm-binding ollama --embedding-binding ollama
light-server --llm-binding ollama --embedding-binding ollama
```

### Starting API Server with Gunicorn (Production)

For production deployments, it's recommended to use Gunicorn as the WSGI server to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionality.

```bash
# Start with lightrag-gunicorn command
lightrag-gunicorn --workers 4

# Alternatively, you can use the module directly
python -m lightrag.api.run_with_gunicorn --workers 4
```

The `--workers` parameter is crucial for performance:

- Determines how many worker processes Gunicorn will spawn to handle requests
- Each worker can handle concurrent requests using asyncio
- The recommended value is (2 x number_of_cores) + 1; see the example after this list
- For example, on a 4-core machine, use 9 workers: (2 x 4) + 1 = 9
- Consider your server's memory when setting this value, as each worker consumes memory
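
If you want to derive the value from the machine you are deploying on, the rule of thumb above can be applied directly in the shell (a sketch only; check memory headroom before going this high):

```bash
# (2 x CPU cores) + 1, e.g. 9 workers on a 4-core machine
WORKERS=$(( $(nproc) * 2 + 1 ))
lightrag-gunicorn --workers "$WORKERS"
```
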
Other important startup parameters:

- `--host`: Server listening address (default: 0.0.0.0)
- `--port`: Server listening port (default: 9621)
- `--timeout`: Request handling timeout (default: 150 seconds)
- `--log-level`: Logging level (default: INFO)
- `--ssl`: Enable HTTPS
- `--ssl-certfile`: Path to SSL certificate file
- `--ssl-keyfile`: Path to SSL private key file

The command line parameters and environment variables accepted by run_with_gunicorn.py are exactly the same as those of `light-server`.
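
For example, these two invocations are equivalent (illustrative values; HOST, PORT and TIMEOUT are the environment counterparts of the flags listed above):

```bash
# Configuration via environment variables
HOST=0.0.0.0 PORT=9621 TIMEOUT=150 lightrag-gunicorn --workers 4

# The same configuration via command line arguments
lightrag-gunicorn --workers 4 --host 0.0.0.0 --port 9621 --timeout 150
```
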
### For Azure OpenAI Backend

Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):

```bash
# Change the resource group name, location and OpenAI resource name as needed

@@ -186,7 +221,7 @@ LightRAG supports binding to various LLM/Embedding backends:
* openai & openai compatible
* azure_openai

Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type.
Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select the LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select the embedding backend type.

### Storage Types Supported

lightrag/api/gunicorn_config.py (new file, 187 lines)
@@ -0,0 +1,187 @@
|
||||
# gunicorn_config.py
|
||||
import os
|
||||
import logging
|
||||
from lightrag.kg.shared_storage import finalize_share_data
|
||||
from lightrag.api.lightrag_server import LightragPathFilter
|
||||
|
||||
# Get log directory path from environment variable
|
||||
log_dir = os.getenv("LOG_DIR", os.getcwd())
|
||||
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
|
||||
|
||||
# Get log file max size and backup count from environment variables
|
||||
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
|
||||
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
|
||||
|
||||
# These variables will be set by run_with_gunicorn.py
|
||||
workers = None
|
||||
bind = None
|
||||
loglevel = None
|
||||
certfile = None
|
||||
keyfile = None
|
||||
|
||||
# Enable preload_app option
|
||||
preload_app = True
|
||||
|
||||
# Use Uvicorn worker
|
||||
worker_class = "uvicorn.workers.UvicornWorker"
|
||||
|
||||
# Other Gunicorn configurations
|
||||
timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py
|
||||
keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s
|
||||
|
||||
# Logging configuration
|
||||
errorlog = os.getenv("ERROR_LOG", log_file_path) # Default write to lightrag.log
|
||||
accesslog = os.getenv("ACCESS_LOG", log_file_path) # Default write to lightrag.log
|
||||
|
||||
logconfig_dict = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "standard",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"file": {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"formatter": "standard",
|
||||
"filename": log_file_path,
|
||||
"maxBytes": log_max_bytes,
|
||||
"backupCount": log_backup_count,
|
||||
"encoding": "utf8",
|
||||
},
|
||||
},
|
||||
"filters": {
|
||||
"path_filter": {
|
||||
"()": "lightrag.api.lightrag_server.LightragPathFilter",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"lightrag": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": loglevel.upper() if loglevel else "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": loglevel.upper() if loglevel else "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn.error": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": loglevel.upper() if loglevel else "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn.access": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": loglevel.upper() if loglevel else "INFO",
|
||||
"propagate": False,
|
||||
"filters": ["path_filter"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def on_starting(server):
|
||||
"""
|
||||
Executed when Gunicorn starts, before forking the first worker processes
|
||||
You can use this function to do more initialization tasks for all processes
|
||||
"""
|
||||
print("=" * 80)
|
||||
print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)")
|
||||
print(f"Process ID: {os.getpid()}")
|
||||
print("=" * 80)
|
||||
|
||||
# Memory usage monitoring
|
||||
try:
|
||||
import psutil
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
memory_info = process.memory_info()
|
||||
msg = (
|
||||
f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB"
|
||||
)
|
||||
print(msg)
|
||||
except ImportError:
|
||||
print("psutil not installed, skipping memory usage reporting")
|
||||
|
||||
print("Gunicorn initialization complete, forking workers...\n")
|
||||
|
||||
|
||||
def on_exit(server):
|
||||
"""
|
||||
Executed when Gunicorn is shutting down.
|
||||
This is a good place to release shared resources.
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("GUNICORN MASTER PROCESS: Shutting down")
|
||||
print(f"Process ID: {os.getpid()}")
|
||||
print("=" * 80)
|
||||
|
||||
# Release shared resources
|
||||
finalize_share_data()
|
||||
|
||||
print("=" * 80)
|
||||
print("Gunicorn shutdown complete")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def post_fork(server, worker):
|
||||
"""
|
||||
Executed after a worker has been forked.
|
||||
This is a good place to set up worker-specific configurations.
|
||||
"""
|
||||
# Configure formatters
|
||||
detailed_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
simple_formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||
|
||||
def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False):
|
||||
"""Set up a logger with console and file handlers"""
|
||||
logger_instance = logging.getLogger(logger_name)
|
||||
logger_instance.setLevel(level)
|
||||
logger_instance.handlers = [] # Clear existing handlers
|
||||
logger_instance.propagate = False
|
||||
|
||||
# Add console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(simple_formatter)
|
||||
console_handler.setLevel(level)
|
||||
logger_instance.addHandler(console_handler)
|
||||
|
||||
# Add file handler
|
||||
file_handler = logging.handlers.RotatingFileHandler(
|
||||
filename=log_file_path,
|
||||
maxBytes=log_max_bytes,
|
||||
backupCount=log_backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(detailed_formatter)
|
||||
file_handler.setLevel(level)
|
||||
logger_instance.addHandler(file_handler)
|
||||
|
||||
# Add path filter if requested
|
||||
if add_filter:
|
||||
path_filter = LightragPathFilter()
|
||||
logger_instance.addFilter(path_filter)
|
||||
|
||||
# Set up main loggers
|
||||
log_level = loglevel.upper() if loglevel else "INFO"
|
||||
setup_logger("uvicorn", log_level)
|
||||
setup_logger("uvicorn.access", log_level, add_filter=True)
|
||||
setup_logger("lightrag", log_level, add_filter=True)
|
||||
|
||||
# Set up lightrag submodule loggers
|
||||
for name in logging.root.manager.loggerDict:
|
||||
if name.startswith("lightrag."):
|
||||
setup_logger(name, log_level, add_filter=True)
|
||||
|
||||
# Disable uvicorn.error logger
|
||||
uvicorn_error_logger = logging.getLogger("uvicorn.error")
|
||||
uvicorn_error_logger.handlers = []
|
||||
uvicorn_error_logger.setLevel(logging.CRITICAL)
|
||||
uvicorn_error_logger.propagate = False
|
@@ -8,11 +8,12 @@ from fastapi import (
|
||||
)
|
||||
from fastapi.responses import FileResponse
|
||||
import asyncio
|
||||
import threading
|
||||
import os
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import logging
|
||||
from typing import Dict
|
||||
import logging.config
|
||||
import uvicorn
|
||||
import pipmaster as pm
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pathlib import Path
|
||||
import configparser
|
||||
from ascii_colors import ASCIIColors
|
||||
@@ -29,7 +30,6 @@ from lightrag import LightRAG
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
from lightrag.api import __api_version__
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag.utils import logger
|
||||
from .routers.document_routes import (
|
||||
DocumentManager,
|
||||
create_document_routes,
|
||||
@@ -39,33 +39,25 @@ from .routers.query_routes import create_query_routes
|
||||
from .routers.graph_routes import create_graph_routes
|
||||
from .routers.ollama_api import OllamaAPI
|
||||
|
||||
from lightrag.utils import logger, set_verbose_debug
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_pipeline_status_lock,
|
||||
initialize_pipeline_status,
|
||||
get_all_update_flags_status,
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
try:
|
||||
load_dotenv(override=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load .env file: {e}")
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Initialize config parser
|
||||
config = configparser.ConfigParser()
|
||||
config.read("config.ini")
|
||||
|
||||
# Global configuration
|
||||
global_top_k = 60 # default value
|
||||
|
||||
# Global progress tracker
|
||||
scan_progress: Dict = {
|
||||
"is_scanning": False,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
}
|
||||
class LightragPathFilter(logging.Filter):
|
||||
"""Filter for lightrag logger to filter out frequent path access logs"""
|
||||
|
||||
# Lock for thread-safe operations
|
||||
progress_lock = threading.Lock()
|
||||
|
||||
|
||||
class AccessLogFilter(logging.Filter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Define paths to be filtered
|
||||
@@ -73,17 +65,18 @@ class AccessLogFilter(logging.Filter):
|
||||
|
||||
def filter(self, record):
|
||||
try:
|
||||
# Check if record has the required attributes for an access log
|
||||
if not hasattr(record, "args") or not isinstance(record.args, tuple):
|
||||
return True
|
||||
if len(record.args) < 5:
|
||||
return True
|
||||
|
||||
# Extract method, path and status from the record args
|
||||
method = record.args[1]
|
||||
path = record.args[2]
|
||||
status = record.args[4]
|
||||
# print(f"Debug - Method: {method}, Path: {path}, Status: {status}")
|
||||
# print(f"Debug - Filtered paths: {self.filtered_paths}")
|
||||
|
||||
# Filter out successful GET requests to filtered paths
|
||||
if (
|
||||
method == "GET"
|
||||
and (status == 200 or status == 304)
|
||||
@@ -92,19 +85,14 @@ class AccessLogFilter(logging.Filter):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
# In case of any error, let the message through
|
||||
return True
|
||||
|
||||
|
||||
def create_app(args):
|
||||
# Set global top_k
|
||||
global global_top_k
|
||||
global_top_k = args.top_k # save top_k from args
|
||||
|
||||
# Initialize verbose debug setting
|
||||
from lightrag.utils import set_verbose_debug
|
||||
|
||||
# Setup logging
|
||||
logger.setLevel(args.log_level)
|
||||
set_verbose_debug(args.verbose)
|
||||
|
||||
# Verify that bindings are correctly setup
|
||||
@@ -138,11 +126,6 @@ def create_app(args):
|
||||
if not os.path.exists(args.ssl_keyfile):
|
||||
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
|
||||
)
|
||||
|
||||
# Check if API key is provided either through env var or args
|
||||
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key
|
||||
|
||||
@@ -158,28 +141,23 @@ def create_app(args):
|
||||
try:
|
||||
# Initialize database connections
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status()
|
||||
|
||||
# Auto scan documents if enabled
|
||||
if args.auto_scan_at_startup:
|
||||
# Start scanning in background
|
||||
with progress_lock:
|
||||
if not scan_progress["is_scanning"]:
|
||||
scan_progress["is_scanning"] = True
|
||||
scan_progress["indexed_count"] = 0
|
||||
scan_progress["progress"] = 0
|
||||
# Create background task
|
||||
task = asyncio.create_task(
|
||||
run_scanning_process(rag, doc_manager)
|
||||
)
|
||||
app.state.background_tasks.add(task)
|
||||
task.add_done_callback(app.state.background_tasks.discard)
|
||||
ASCIIColors.info(
|
||||
f"Started background scanning of documents from {args.input_dir}"
|
||||
)
|
||||
else:
|
||||
ASCIIColors.info(
|
||||
"Skip document scanning(another scanning is active)"
|
||||
)
|
||||
# Check if a task is already running (with lock protection)
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
should_start_task = False
|
||||
async with get_pipeline_status_lock():
|
||||
if not pipeline_status.get("busy", False):
|
||||
should_start_task = True
|
||||
# Only start the task if no other task is running
|
||||
if should_start_task:
|
||||
# Create background task
|
||||
task = asyncio.create_task(run_scanning_process(rag, doc_manager))
|
||||
app.state.background_tasks.add(task)
|
||||
task.add_done_callback(app.state.background_tasks.discard)
|
||||
logger.info("Auto scan task started at startup.")
|
||||
|
||||
ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
|
||||
|
||||
@@ -398,6 +376,9 @@ def create_app(args):
|
||||
@app.get("/health", dependencies=[Depends(optional_api_key)])
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
# Get update flags status for all namespaces
|
||||
update_status = await get_all_update_flags_status()
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"working_directory": str(args.working_dir),
|
||||
@@ -417,6 +398,7 @@ def create_app(args):
|
||||
"graph_storage": args.graph_storage,
|
||||
"vector_storage": args.vector_storage,
|
||||
},
|
||||
"update_status": update_status,
|
||||
}
|
||||
|
||||
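
As a quick smoke test of the extended payload, the health endpoint can be queried once a worker is serving (default host and port assumed; add the API key header if one is configured):

```bash
curl -s http://localhost:9621/health
```
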
# Webui mount webui/index.html
|
||||
@@ -435,12 +417,30 @@ def create_app(args):
|
||||
return app
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
import uvicorn
|
||||
import logging.config
|
||||
def get_application(args=None):
|
||||
"""Factory function for creating the FastAPI application"""
|
||||
if args is None:
|
||||
args = parse_args()
|
||||
return create_app(args)
|
||||
|
||||
|
||||
def configure_logging():
|
||||
"""Configure logging for uvicorn startup"""
|
||||
|
||||
# Reset any existing handlers to ensure clean configuration
|
||||
for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.handlers = []
|
||||
logger.filters = []
|
||||
|
||||
# Get log directory path from environment variable
|
||||
log_dir = os.getenv("LOG_DIR", os.getcwd())
|
||||
log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
|
||||
|
||||
# Get log file max size and backup count from environment variables
|
||||
log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
|
||||
log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
|
||||
|
||||
# Configure uvicorn logging
|
||||
logging.config.dictConfig(
|
||||
{
|
||||
"version": 1,
|
||||
@@ -449,36 +449,106 @@ def main():
|
||||
"default": {
|
||||
"format": "%(levelname)s: %(message)s",
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"default": {
|
||||
"console": {
|
||||
"formatter": "default",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stderr",
|
||||
},
|
||||
"file": {
|
||||
"formatter": "detailed",
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"filename": log_file_path,
|
||||
"maxBytes": log_max_bytes,
|
||||
"backupCount": log_backup_count,
|
||||
"encoding": "utf-8",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"uvicorn.access": {
|
||||
"handlers": ["default"],
|
||||
# Configure all uvicorn related loggers
|
||||
"uvicorn": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
"filters": ["path_filter"],
|
||||
},
|
||||
"uvicorn.error": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"lightrag": {
|
||||
"handlers": ["console", "file"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
"filters": ["path_filter"],
|
||||
},
|
||||
},
|
||||
"filters": {
|
||||
"path_filter": {
|
||||
"()": "lightrag.api.lightrag_server.LightragPathFilter",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add filter to uvicorn access logger
|
||||
uvicorn_access_logger = logging.getLogger("uvicorn.access")
|
||||
uvicorn_access_logger.addFilter(AccessLogFilter())
|
||||
|
||||
app = create_app(args)
|
||||
def check_and_install_dependencies():
|
||||
"""Check and install required dependencies"""
|
||||
required_packages = [
|
||||
"uvicorn",
|
||||
"tiktoken",
|
||||
"fastapi",
|
||||
# Add other required packages here
|
||||
]
|
||||
|
||||
for package in required_packages:
|
||||
if not pm.is_installed(package):
|
||||
print(f"Installing {package}...")
|
||||
pm.install(package)
|
||||
print(f"{package} installed successfully")
|
||||
|
||||
|
||||
def main():
|
||||
# Check if running under Gunicorn
|
||||
if "GUNICORN_CMD_ARGS" in os.environ:
|
||||
# If started with Gunicorn, return directly as Gunicorn will call get_application
|
||||
print("Running under Gunicorn - worker management handled by Gunicorn")
|
||||
return
|
||||
|
||||
# Check and install dependencies
|
||||
check_and_install_dependencies()
|
||||
|
||||
from multiprocessing import freeze_support
|
||||
|
||||
freeze_support()
|
||||
|
||||
# Configure logging before parsing args
|
||||
configure_logging()
|
||||
|
||||
args = parse_args(is_uvicorn_mode=True)
|
||||
display_splash_screen(args)
|
||||
|
||||
# Create application instance directly instead of using factory function
|
||||
app = create_app(args)
|
||||
|
||||
# Start Uvicorn in single process mode
|
||||
uvicorn_config = {
|
||||
"app": app,
|
||||
"app": app, # Pass application instance directly instead of string path
|
||||
"host": args.host,
|
||||
"port": args.port,
|
||||
"log_config": None, # Disable default config
|
||||
}
|
||||
|
||||
if args.ssl:
|
||||
uvicorn_config.update(
|
||||
{
|
||||
@@ -486,6 +556,8 @@ def main():
|
||||
"ssl_keyfile": args.ssl_keyfile,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}")
|
||||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
|
||||
|
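
Taken together with run_with_gunicorn.py below, the two entry points give two process models (a summary of the behavior configured above, not new functionality):

```bash
# Single process managed by Uvicorn; --workers is forced back to 1 in this mode
lightrag-server

# Multiple worker processes managed by Gunicorn using the UvicornWorker class
lightrag-gunicorn --workers 4
```
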
@@ -3,8 +3,7 @@ This module contains all document-related routes for the LightRAG API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from lightrag.utils import logger
|
||||
import aiofiles
|
||||
import shutil
|
||||
import traceback
|
||||
@@ -12,7 +11,6 @@ import pipmaster as pm
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -23,18 +21,6 @@ from ..utils_api import get_api_key_dependency
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
|
||||
# Global progress tracker
|
||||
scan_progress: Dict = {
|
||||
"is_scanning": False,
|
||||
"current_file": "",
|
||||
"indexed_count": 0,
|
||||
"total_files": 0,
|
||||
"progress": 0,
|
||||
}
|
||||
|
||||
# Lock for thread-safe operations
|
||||
progress_lock = asyncio.Lock()
|
||||
|
||||
# Temporary file prefix
|
||||
temp_prefix = "__tmp__"
|
||||
|
||||
@@ -161,19 +147,12 @@ class DocumentManager:
|
||||
"""Scan input directory for new files"""
|
||||
new_files = []
|
||||
for ext in self.supported_extensions:
|
||||
logging.debug(f"Scanning for {ext} files in {self.input_dir}")
|
||||
logger.debug(f"Scanning for {ext} files in {self.input_dir}")
|
||||
for file_path in self.input_dir.rglob(f"*{ext}"):
|
||||
if file_path not in self.indexed_files:
|
||||
new_files.append(file_path)
|
||||
return new_files
|
||||
|
||||
# def scan_directory(self) -> List[Path]:
|
||||
# new_files = []
|
||||
# for ext in self.supported_extensions:
|
||||
# for file_path in self.input_dir.rglob(f"*{ext}"):
|
||||
# new_files.append(file_path)
|
||||
# return new_files
|
||||
|
||||
def mark_as_indexed(self, file_path: Path):
|
||||
self.indexed_files.add(file_path)
|
||||
|
||||
@@ -287,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
||||
)
|
||||
content += "\n"
|
||||
case _:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"Unsupported file type: {file_path.name} (extension {ext})"
|
||||
)
|
||||
return False
|
||||
@@ -295,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
|
||||
# Insert into the RAG queue
|
||||
if content:
|
||||
await rag.apipeline_enqueue_documents(content)
|
||||
logging.info(f"Successfully fetched and enqueued file: {file_path.name}")
|
||||
logger.info(f"Successfully fetched and enqueued file: {file_path.name}")
|
||||
return True
|
||||
else:
|
||||
logging.error(f"No content could be extracted from file: {file_path.name}")
|
||||
logger.error(f"No content could be extracted from file: {file_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
if file_path.name.startswith(temp_prefix):
|
||||
try:
|
||||
file_path.unlink()
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting file {file_path}: {str(e)}")
|
||||
logger.error(f"Error deleting file {file_path}: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
@@ -324,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path):
|
||||
await rag.apipeline_process_enqueue_documents()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error indexing file {file_path.name}: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error indexing file {file_path.name}: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
|
||||
@@ -349,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]):
|
||||
if enqueued:
|
||||
await rag.apipeline_process_enqueue_documents()
|
||||
except Exception as e:
|
||||
logging.error(f"Error indexing files: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error indexing files: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def pipeline_index_texts(rag: LightRAG, texts: List[str]):
|
||||
@@ -393,30 +372,17 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
|
||||
"""Background task to scan and index documents"""
|
||||
try:
|
||||
new_files = doc_manager.scan_directory_for_new_files()
|
||||
scan_progress["total_files"] = len(new_files)
|
||||
total_files = len(new_files)
|
||||
logger.info(f"Found {total_files} new files to index.")
|
||||
|
||||
logging.info(f"Found {len(new_files)} new files to index.")
|
||||
for file_path in new_files:
|
||||
for idx, file_path in enumerate(new_files):
|
||||
try:
|
||||
async with progress_lock:
|
||||
scan_progress["current_file"] = os.path.basename(file_path)
|
||||
|
||||
await pipeline_index_file(rag, file_path)
|
||||
|
||||
async with progress_lock:
|
||||
scan_progress["indexed_count"] += 1
|
||||
scan_progress["progress"] = (
|
||||
scan_progress["indexed_count"] / scan_progress["total_files"]
|
||||
) * 100
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error indexing file {file_path}: {str(e)}")
|
||||
logger.error(f"Error indexing file {file_path}: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during scanning process: {str(e)}")
|
||||
finally:
|
||||
async with progress_lock:
|
||||
scan_progress["is_scanning"] = False
|
||||
logger.error(f"Error during scanning process: {str(e)}")
|
||||
|
||||
|
||||
def create_document_routes(
|
||||
@@ -436,34 +402,10 @@ def create_document_routes(
|
||||
Returns:
|
||||
dict: A dictionary containing the scanning status
|
||||
"""
|
||||
async with progress_lock:
|
||||
if scan_progress["is_scanning"]:
|
||||
return {"status": "already_scanning"}
|
||||
|
||||
scan_progress["is_scanning"] = True
|
||||
scan_progress["indexed_count"] = 0
|
||||
scan_progress["progress"] = 0
|
||||
|
||||
# Start the scanning process in the background
|
||||
background_tasks.add_task(run_scanning_process, rag, doc_manager)
|
||||
return {"status": "scanning_started"}
|
||||
|
||||
@router.get("/scan-progress")
|
||||
async def get_scan_progress():
|
||||
"""
|
||||
Get the current progress of the document scanning process.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current scanning progress information including:
|
||||
- is_scanning: Whether a scan is currently in progress
|
||||
- current_file: The file currently being processed
|
||||
- indexed_count: Number of files indexed so far
|
||||
- total_files: Total number of files to process
|
||||
- progress: Percentage of completion
|
||||
"""
|
||||
async with progress_lock:
|
||||
return scan_progress
|
||||
|
||||
@router.post("/upload", dependencies=[Depends(optional_api_key)])
|
||||
async def upload_to_input_dir(
|
||||
background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||
@@ -504,8 +446,8 @@ def create_document_routes(
|
||||
message=f"File '{file.filename}' uploaded successfully. Processing will continue in background.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error /documents/upload: {file.filename}: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error /documents/upload: {file.filename}: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
@@ -537,8 +479,8 @@ def create_document_routes(
|
||||
message="Text successfully received. Processing will continue in background.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error /documents/text: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error /documents/text: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
@@ -572,8 +514,8 @@ def create_document_routes(
|
||||
message="Text successfully received. Processing will continue in background.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error /documents/text: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error /documents/text: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
@@ -615,8 +557,8 @@ def create_document_routes(
|
||||
message=f"File '{file.filename}' saved successfully. Processing will continue in background.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error /documents/file: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error /documents/file: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post(
|
||||
@@ -678,8 +620,8 @@ def create_document_routes(
|
||||
|
||||
return InsertResponse(status=status, message=status_message)
|
||||
except Exception as e:
|
||||
logging.error(f"Error /documents/batch: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error /documents/batch: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete(
|
||||
@@ -706,8 +648,42 @@ def create_document_routes(
|
||||
status="success", message="All documents cleared successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error DELETE /documents: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error DELETE /documents: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/pipeline_status", dependencies=[Depends(optional_api_key)])
|
||||
async def get_pipeline_status():
|
||||
"""
|
||||
Get the current status of the document indexing pipeline.
|
||||
|
||||
This endpoint returns information about the current state of the document processing pipeline,
|
||||
including whether it's busy, the current job name, when it started, how many documents
|
||||
are being processed, how many batches there are, and which batch is currently being processed.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the pipeline status information
|
||||
"""
|
||||
try:
|
||||
from lightrag.kg.shared_storage import get_namespace_data
|
||||
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
|
||||
# Convert to regular dict if it's a Manager.dict
|
||||
status_dict = dict(pipeline_status)
|
||||
|
||||
# Convert history_messages to a regular list if it's a Manager.list
|
||||
if "history_messages" in status_dict:
|
||||
status_dict["history_messages"] = list(status_dict["history_messages"])
|
||||
|
||||
# Format the job_start time if it exists
|
||||
if status_dict.get("job_start"):
|
||||
status_dict["job_start"] = str(status_dict["job_start"])
|
||||
|
||||
return status_dict
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pipeline status: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
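
Since the documents router is mounted under /documents, the new status endpoint can be polled like this (default host and port assumed; include the API key header if one is configured):

```bash
curl -s http://localhost:9621/documents/pipeline_status
```
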
@router.get("", dependencies=[Depends(optional_api_key)])
|
||||
@@ -763,8 +739,8 @@ def create_document_routes(
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
logging.error(f"Error GET /documents: {str(e)}")
|
||||
logging.error(traceback.format_exc())
|
||||
logger.error(f"Error GET /documents: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return router
|
||||
|
lightrag/api/run_with_gunicorn.py (new file, 203 lines)
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Start LightRAG server with Gunicorn
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import pipmaster as pm
|
||||
from lightrag.api.utils_api import parse_args, display_splash_screen
|
||||
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data
|
||||
|
||||
|
||||
def check_and_install_dependencies():
|
||||
"""Check and install required dependencies"""
|
||||
required_packages = [
|
||||
"gunicorn",
|
||||
"tiktoken",
|
||||
"psutil",
|
||||
# Add other required packages here
|
||||
]
|
||||
|
||||
for package in required_packages:
|
||||
if not pm.is_installed(package):
|
||||
print(f"Installing {package}...")
|
||||
pm.install(package)
|
||||
print(f"{package} installed successfully")
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig, frame):
|
||||
print("\n\n" + "=" * 80)
|
||||
print("RECEIVED TERMINATION SIGNAL")
|
||||
print(f"Process ID: {os.getpid()}")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
# Release shared resources
|
||||
finalize_share_data()
|
||||
|
||||
# Exit with success status
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
# Check and install dependencies
|
||||
check_and_install_dependencies()
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # kill command
|
||||
|
||||
# Parse all arguments using parse_args
|
||||
args = parse_args(is_uvicorn_mode=False)
|
||||
|
||||
# Display startup information
|
||||
display_splash_screen(args)
|
||||
|
||||
print("🚀 Starting LightRAG with Gunicorn")
|
||||
print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
|
||||
print("🔍 Preloading app: Enabled")
|
||||
print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
|
||||
print("\n\n" + "=" * 80)
|
||||
print("MAIN PROCESS INITIALIZATION")
|
||||
print(f"Process ID: {os.getpid()}")
|
||||
print(f"Workers setting: {args.workers}")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
# Import Gunicorn's StandaloneApplication
|
||||
from gunicorn.app.base import BaseApplication
|
||||
|
||||
# Define a custom application class that loads our config
|
||||
class GunicornApp(BaseApplication):
|
||||
def __init__(self, app, options=None):
|
||||
self.options = options or {}
|
||||
self.application = app
|
||||
super().__init__()
|
||||
|
||||
def load_config(self):
|
||||
# Define valid Gunicorn configuration options
|
||||
valid_options = {
|
||||
"bind",
|
||||
"workers",
|
||||
"worker_class",
|
||||
"timeout",
|
||||
"keepalive",
|
||||
"preload_app",
|
||||
"errorlog",
|
||||
"accesslog",
|
||||
"loglevel",
|
||||
"certfile",
|
||||
"keyfile",
|
||||
"limit_request_line",
|
||||
"limit_request_fields",
|
||||
"limit_request_field_size",
|
||||
"graceful_timeout",
|
||||
"max_requests",
|
||||
"max_requests_jitter",
|
||||
}
|
||||
|
||||
# Special hooks that need to be set separately
|
||||
special_hooks = {
|
||||
"on_starting",
|
||||
"on_reload",
|
||||
"on_exit",
|
||||
"pre_fork",
|
||||
"post_fork",
|
||||
"pre_exec",
|
||||
"pre_request",
|
||||
"post_request",
|
||||
"worker_init",
|
||||
"worker_exit",
|
||||
"nworkers_changed",
|
||||
"child_exit",
|
||||
}
|
||||
|
||||
# Import and configure the gunicorn_config module
|
||||
from lightrag.api import gunicorn_config
|
||||
|
||||
# Set configuration variables in gunicorn_config, prioritizing command line arguments
|
||||
gunicorn_config.workers = (
|
||||
args.workers if args.workers else int(os.getenv("WORKERS", 1))
|
||||
)
|
||||
|
||||
# Bind configuration prioritizes command line arguments
|
||||
host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
|
||||
port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
|
||||
gunicorn_config.bind = f"{host}:{port}"
|
||||
|
||||
# Log level configuration prioritizes command line arguments
|
||||
gunicorn_config.loglevel = (
|
||||
args.log_level.lower()
|
||||
if args.log_level
|
||||
else os.getenv("LOG_LEVEL", "info")
|
||||
)
|
||||
|
||||
# Timeout configuration prioritizes command line arguments
|
||||
gunicorn_config.timeout = (
|
||||
args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
|
||||
)
|
||||
|
||||
# Keepalive configuration
|
||||
gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))
|
||||
|
||||
# SSL configuration prioritizes command line arguments
|
||||
if args.ssl or os.getenv("SSL", "").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
"t",
|
||||
"on",
|
||||
):
|
||||
gunicorn_config.certfile = (
|
||||
args.ssl_certfile
|
||||
if args.ssl_certfile
|
||||
else os.getenv("SSL_CERTFILE")
|
||||
)
|
||||
gunicorn_config.keyfile = (
|
||||
args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
|
||||
)
|
||||
|
||||
# Set configuration options from the module
|
||||
for key in dir(gunicorn_config):
|
||||
if key in valid_options:
|
||||
value = getattr(gunicorn_config, key)
|
||||
# Skip functions like on_starting and None values
|
||||
if not callable(value) and value is not None:
|
||||
self.cfg.set(key, value)
|
||||
# Set special hooks
|
||||
elif key in special_hooks:
|
||||
value = getattr(gunicorn_config, key)
|
||||
if callable(value):
|
||||
self.cfg.set(key, value)
|
||||
|
||||
if hasattr(gunicorn_config, "logconfig_dict"):
|
||||
self.cfg.set(
|
||||
"logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
|
||||
)
|
||||
|
||||
def load(self):
|
||||
# Import the application
|
||||
from lightrag.api.lightrag_server import get_application
|
||||
|
||||
return get_application(args)
|
||||
|
||||
# Create the application
|
||||
app = GunicornApp("")
|
||||
|
||||
# Force workers to be an integer and greater than 1 for multi-process mode
|
||||
workers_count = int(args.workers)
|
||||
if workers_count > 1:
|
||||
# Set a flag to indicate we're in the main process
|
||||
os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
|
||||
initialize_share_data(workers_count)
|
||||
else:
|
||||
initialize_share_data(1)
|
||||
|
||||
# Run the application
|
||||
print("\nStarting Gunicorn with direct Python API...")
|
||||
app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
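
The precedence implemented in load_config() above means an explicit command line flag always wins over the corresponding environment variable; for example (illustrative values):

```bash
# No --workers flag: the WORKERS environment variable (or 1) is used
WORKERS=8 lightrag-gunicorn

# Explicit flag overrides the environment: this run uses 2 workers
WORKERS=8 lightrag-gunicorn --workers 2
```
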
@@ -6,6 +6,7 @@ import os
|
||||
import argparse
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__
|
||||
from fastapi import HTTPException, Security
|
||||
@@ -110,10 +111,13 @@ def get_env_value(env_key: str, default: any, value_type: type = str) -> any:
|
||||
return default
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
||||
"""
|
||||
Parse command line arguments with environment variable fallback
|
||||
|
||||
Args:
|
||||
is_uvicorn_mode: Whether running under uvicorn mode
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
@@ -260,6 +264,14 @@ def parse_args() -> argparse.Namespace:
|
||||
help="Enable automatic scanning when the program starts",
|
||||
)
|
||||
|
||||
# Server workers configuration
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=get_env_value("WORKERS", 1, int),
|
||||
help="Number of worker processes (default: from env or 1)",
|
||||
)
|
||||
|
||||
# LLM and embedding bindings
|
||||
parser.add_argument(
|
||||
"--llm-binding",
|
||||
@@ -278,6 +290,15 @@ def parse_args() -> argparse.Namespace:
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# If in uvicorn mode and workers > 1, force it to 1 and log warning
|
||||
if is_uvicorn_mode and args.workers > 1:
|
||||
original_workers = args.workers
|
||||
args.workers = 1
|
||||
# Log warning directly here
|
||||
logging.warning(
|
||||
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
|
||||
)
|
||||
|
||||
# convert relative path to absolute path
|
||||
args.working_dir = os.path.abspath(args.working_dir)
|
||||
args.input_dir = os.path.abspath(args.input_dir)
|
||||
@@ -346,17 +367,27 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.yellow(f"{args.host}")
|
||||
ASCIIColors.white(" ├─ Port: ", end="")
|
||||
ASCIIColors.yellow(f"{args.port}")
|
||||
ASCIIColors.white(" ├─ Workers: ", end="")
|
||||
ASCIIColors.yellow(f"{args.workers}")
|
||||
ASCIIColors.white(" ├─ CORS Origins: ", end="")
|
||||
ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}")
|
||||
ASCIIColors.white(" ├─ SSL Enabled: ", end="")
|
||||
ASCIIColors.yellow(f"{args.ssl}")
|
||||
ASCIIColors.white(" └─ API Key: ", end="")
|
||||
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
||||
if args.ssl:
|
||||
ASCIIColors.white(" ├─ SSL Cert: ", end="")
|
||||
ASCIIColors.yellow(f"{args.ssl_certfile}")
|
||||
ASCIIColors.white(" └─ SSL Key: ", end="")
|
||||
ASCIIColors.white(" ├─ SSL Key: ", end="")
|
||||
ASCIIColors.yellow(f"{args.ssl_keyfile}")
|
||||
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
|
||||
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
|
||||
ASCIIColors.white(" ├─ Log Level: ", end="")
|
||||
ASCIIColors.yellow(f"{args.log_level}")
|
||||
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
|
||||
ASCIIColors.yellow(f"{args.verbose}")
|
||||
ASCIIColors.white(" ├─ Timeout: ", end="")
|
||||
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
|
||||
ASCIIColors.white(" └─ API Key: ", end="")
|
||||
ASCIIColors.yellow("Set" if args.key else "Not Set")
|
||||
|
||||
# Directory Configuration
|
||||
ASCIIColors.magenta("\n📂 Directory Configuration:")
|
||||
@@ -415,16 +446,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.white(" └─ Document Status Storage: ", end="")
|
||||
ASCIIColors.yellow(f"{args.doc_status_storage}")
|
||||
|
||||
ASCIIColors.magenta("\n🛠️ System Configuration:")
|
||||
ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="")
|
||||
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
|
||||
ASCIIColors.white(" ├─ Log Level: ", end="")
|
||||
ASCIIColors.yellow(f"{args.log_level}")
|
||||
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
|
||||
ASCIIColors.yellow(f"{args.verbose}")
|
||||
ASCIIColors.white(" └─ Timeout: ", end="")
|
||||
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
|
||||
|
||||
# Server Status
|
||||
ASCIIColors.green("\n✨ Server starting up...\n")
|
||||
|
||||
@@ -478,7 +499,6 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
||||
ASCIIColors.cyan(""" 3. Basic Operations:
|
||||
- POST /upload_document: Upload new documents to RAG
|
||||
- POST /query: Query your document collection
|
||||
- GET /collections: List available collections
|
||||
|
||||
4. Monitor the server:
|
||||
- Check server logs for detailed operation information
|
||||
|
@@ -2,25 +2,25 @@ import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any, final
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
import pipmaster as pm
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
compute_mdhash_id,
|
||||
)
|
||||
from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
from lightrag.utils import logger, compute_mdhash_id
|
||||
from lightrag.base import BaseVectorStorage
|
||||
|
||||
if not pm.is_installed("faiss"):
|
||||
pm.install("faiss")
|
||||
|
||||
import faiss
|
||||
import faiss # type: ignore
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@@ -55,14 +55,40 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
||||
# For demonstration, we use a simple IndexFlatIP.
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
|
||||
# Keep a local store for metadata, IDs, etc.
|
||||
# Maps <int faiss_id> → metadata (including your original ID).
|
||||
self._id_to_meta = {}
|
||||
|
||||
# Attempt to load an existing index + metadata from disk
|
||||
self._load_faiss_index()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(self.namespace)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
async def _get_index(self):
|
||||
"""Check if the shtorage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if (is_multiprocess and self.storage_updated.value) or (
|
||||
not is_multiprocess and self.storage_updated
|
||||
):
|
||||
logger.info(
|
||||
f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
|
||||
)
|
||||
# Reload data
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
return self._index
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Insert or update vectors in the Faiss index.
|
||||
@@ -113,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
return []
|
||||
|
||||
# Normalize embeddings for cosine similarity (in-place)
|
||||
# Convert to float32 and normalize embeddings for cosine similarity (in-place)
|
||||
embeddings = embeddings.astype(np.float32)
|
||||
faiss.normalize_L2(embeddings)
|
||||
|
||||
# Upsert logic:
|
||||
@@ -127,18 +154,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
existing_ids_to_remove.append(faiss_internal_id)
|
||||
|
||||
if existing_ids_to_remove:
|
||||
self._remove_faiss_ids(existing_ids_to_remove)
|
||||
await self._remove_faiss_ids(existing_ids_to_remove)
|
||||
|
||||
# Step 2: Add new vectors
|
||||
start_idx = self._index.ntotal
|
||||
self._index.add(embeddings)
|
||||
index = await self._get_index()
|
||||
start_idx = index.ntotal
|
||||
index.add(embeddings)
|
||||
|
||||
# Step 3: Store metadata + vector for each new ID
|
||||
for i, meta in enumerate(list_data):
|
||||
fid = start_idx + i
|
||||
# Store the raw vector so we can rebuild if something is removed
|
||||
meta["__vector__"] = embeddings[i].tolist()
|
||||
self._id_to_meta[fid] = meta
|
||||
self._id_to_meta.update({fid: meta})
|
||||
|
||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||
return [m["__id__"] for m in list_data]
|
||||
@@ -157,7 +185,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
# Perform the similarity search
|
||||
distances, indices = self._index.search(embedding, top_k)
|
||||
index = await self._get_index()
|
||||
distances, indices = index.search(embedding, top_k)
|
||||
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
@@ -201,8 +230,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
to_remove.append(fid)
|
||||
|
||||
if to_remove:
|
||||
self._remove_faiss_ids(to_remove)
|
||||
logger.info(
|
||||
await self._remove_faiss_ids(to_remove)
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
|
||||
)
|
||||
|
||||
@@ -223,12 +252,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
logger.debug(f"Found {len(relations)} relations for {entity_name}")
|
||||
if relations:
|
||||
self._remove_faiss_ids(relations)
|
||||
await self._remove_faiss_ids(relations)
|
||||
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
self._save_faiss_index()
|
||||
|
||||
# --------------------------------------------------------------------------------
|
||||
# Internal helper methods
|
||||
# --------------------------------------------------------------------------------
|
||||
@@ -242,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
return fid
|
||||
return None
|
||||
|
||||
def _remove_faiss_ids(self, fid_list):
|
||||
async def _remove_faiss_ids(self, fid_list):
|
||||
"""
|
||||
Remove a list of internal Faiss IDs from the index.
|
||||
Because IndexFlatIP doesn't support 'removals',
|
||||
@@ -258,13 +284,14 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
|
||||
new_id_to_meta[new_fid] = vec_meta
|
||||
|
||||
# Re-init index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
if vectors_to_keep:
|
||||
arr = np.array(vectors_to_keep, dtype=np.float32)
|
||||
self._index.add(arr)
|
||||
async with self._storage_lock:
|
||||
# Re-init index
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
if vectors_to_keep:
|
||||
arr = np.array(vectors_to_keep, dtype=np.float32)
|
||||
self._index.add(arr)
|
||||
|
||||
self._id_to_meta = new_id_to_meta
|
||||
self._id_to_meta = new_id_to_meta
|
||||
|
||||
def _save_faiss_index(self):
|
||||
"""
|
||||
@@ -312,3 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
logger.warning("Starting with an empty Faiss index.")
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
async with self._storage_lock:
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
self._save_faiss_index()
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
@@ -12,6 +12,11 @@ from lightrag.utils import (
|
||||
logger,
|
||||
write_json,
|
||||
)
|
||||
from .shared_storage import (
|
||||
get_namespace_data,
|
||||
get_storage_lock,
|
||||
try_initialize_namespace,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@@ -22,26 +27,42 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
self._storage_lock = get_storage_lock()
|
||||
self._data = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# need_init must be checked before get_namespace_data
|
||||
need_init = try_initialize_namespace(self.namespace)
|
||||
self._data = await get_namespace_data(self.namespace)
|
||||
if need_init:
|
||||
loaded_data = load_json(self._file_name) or {}
|
||||
async with self._storage_lock:
|
||||
self._data.update(loaded_data)
|
||||
logger.info(
|
||||
f"Loaded document status storage with {len(loaded_data)} records"
|
||||
)
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
return set(keys) - set(self._data.keys())
|
||||
async with self._storage_lock:
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
result: list[dict[str, Any]] = []
|
||||
for id in ids:
|
||||
data = self._data.get(id, None)
|
||||
if data:
|
||||
result.append(data)
|
||||
async with self._storage_lock:
|
||||
for id in ids:
|
||||
data = self._data.get(id, None)
|
||||
if data:
|
||||
result.append(data)
|
||||
return result
|
||||
|
||||
async def get_status_counts(self) -> dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
counts = {status.value: 0 for status in DocStatus}
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
async with self._storage_lock:
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
return counts
|
||||
|
||||
async def get_docs_by_status(
|
||||
@@ -49,39 +70,48 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
) -> dict[str, DocProcessingStatus]:
|
||||
"""Get all documents with a specific status"""
|
||||
result = {}
|
||||
for k, v in self._data.items():
|
||||
if v["status"] == status.value:
|
||||
try:
|
||||
# Make a copy of the data to avoid modifying the original
|
||||
data = v.copy()
|
||||
# If content is missing, use content_summary as content
|
||||
if "content" not in data and "content_summary" in data:
|
||||
data["content"] = data["content_summary"]
|
||||
result[k] = DocProcessingStatus(**data)
|
||||
except KeyError as e:
|
||||
logger.error(f"Missing required field for document {k}: {e}")
|
||||
continue
|
||||
async with self._storage_lock:
|
||||
for k, v in self._data.items():
|
||||
if v["status"] == status.value:
|
||||
try:
|
||||
# Make a copy of the data to avoid modifying the original
|
||||
data = v.copy()
|
||||
# If content is missing, use content_summary as content
|
||||
if "content" not in data and "content_summary" in data:
|
||||
data["content"] = data["content_summary"]
|
||||
result[k] = DocProcessingStatus(**data)
|
||||
except KeyError as e:
|
||||
logger.error(f"Missing required field for document {k}: {e}")
|
||||
continue
|
||||
return result
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
async with self._storage_lock:
|
||||
data_dict = (
|
||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||
)
|
||||
write_json(data_dict, self._file_name)
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
|
||||
self._data.update(data)
|
||||
async with self._storage_lock:
|
||||
self._data.update(data)
|
||||
await self.index_done_callback()
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
return self._data.get(id)
|
||||
async with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def delete(self, doc_ids: list[str]):
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
async with self._storage_lock:
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await self.index_done_callback()
|
||||
|
||||
async def drop(self) -> None:
|
||||
"""Drop the storage"""
|
||||
self._data.clear()
|
||||
async with self._storage_lock:
|
||||
self._data.clear()
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
@@ -11,6 +10,11 @@ from lightrag.utils import (
|
||||
logger,
|
||||
write_json,
|
||||
)
|
||||
from .shared_storage import (
|
||||
get_namespace_data,
|
||||
get_storage_lock,
|
||||
try_initialize_namespace,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@@ -19,37 +23,56 @@ class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
||||
self._lock = asyncio.Lock()
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
self._storage_lock = get_storage_lock()
|
||||
self._data = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# The need_init check must come before get_namespace_data
|
||||
need_init = try_initialize_namespace(self.namespace)
|
||||
self._data = await get_namespace_data(self.namespace)
|
||||
if need_init:
|
||||
loaded_data = load_json(self._file_name) or {}
|
||||
async with self._storage_lock:
|
||||
self._data.update(loaded_data)
|
||||
logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
write_json(self._data, self._file_name)
|
||||
async with self._storage_lock:
|
||||
data_dict = (
|
||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||
)
|
||||
write_json(data_dict, self._file_name)
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
return self._data.get(id)
|
||||
async with self._storage_lock:
|
||||
return self._data.get(id)
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items()}
|
||||
if self._data.get(id, None)
|
||||
else None
|
||||
)
|
||||
for id in ids
|
||||
]
|
||||
async with self._storage_lock:
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items()}
|
||||
if self._data.get(id, None)
|
||||
else None
|
||||
)
|
||||
for id in ids
|
||||
]
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
return set(keys) - set(self._data.keys())
|
||||
async with self._storage_lock:
|
||||
return set(keys) - set(self._data.keys())
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
return
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
async with self._storage_lock:
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
|
||||
async def delete(self, ids: list[str]) -> None:
|
||||
for doc_id in ids:
|
||||
self._data.pop(doc_id, None)
|
||||
async with self._storage_lock:
|
||||
for doc_id in ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await self.index_done_callback()
|
||||
|
@@ -3,7 +3,6 @@ import os
|
||||
from typing import Any, final
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
import time
|
||||
|
||||
from lightrag.utils import (
|
||||
@@ -11,22 +10,29 @@ from lightrag.utils import (
|
||||
compute_mdhash_id,
|
||||
)
|
||||
import pipmaster as pm
|
||||
from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
from lightrag.base import BaseVectorStorage
|
||||
|
||||
if not pm.is_installed("nano-vectordb"):
|
||||
pm.install("nano-vectordb")
|
||||
|
||||
from nano_vectordb import NanoVectorDB
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
def __post_init__(self):
|
||||
# Initialize lock only for file operations
|
||||
self._save_lock = asyncio.Lock()
|
||||
# Initialize basic attributes
|
||||
self._client = None
|
||||
self._storage_lock = None
|
||||
self.storage_updated = None
|
||||
|
||||
# Use global config value if specified, otherwise use default
|
||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
cosine_threshold = kwargs.get("cosine_better_than_threshold")
|
||||
@@ -40,10 +46,43 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(self.namespace)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
async def _get_client(self):
|
||||
"""Check if the storage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if data needs to be reloaded
|
||||
if (is_multiprocess and self.storage_updated.value) or (
|
||||
not is_multiprocess and self.storage_updated
|
||||
):
|
||||
logger.info(
|
||||
f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
|
||||
)
|
||||
# Reload data
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
|
||||
return self._client
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} to {self.namespace}")
|
||||
if not data:
|
||||
@@ -64,6 +103,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
@@ -71,7 +111,8 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
if len(embeddings) == len(list_data):
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
results = self._client.upsert(datas=list_data)
|
||||
client = await self._get_client()
|
||||
results = client.upsert(datas=list_data)
|
||||
return results
|
||||
else:
|
||||
# Sometimes the embedding is not returned correctly; just log it.
|
||||
@@ -80,9 +121,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
# Execute embedding outside of lock to avoid long lock times
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
results = self._client.query(
|
||||
|
||||
client = await self._get_client()
|
||||
results = client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
@@ -99,8 +143,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
return results
|
||||
|
||||
@property
|
||||
def client_storage(self):
|
||||
return getattr(self._client, "_NanoVectorDB__storage")
|
||||
async def client_storage(self):
|
||||
client = await self._get_client()
|
||||
return getattr(client, "_NanoVectorDB__storage")
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
@@ -109,8 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
try:
|
||||
self._client.delete(ids)
|
||||
logger.info(
|
||||
client = await self._get_client()
|
||||
client.delete(ids)
|
||||
logger.debug(
|
||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -122,9 +168,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
logger.debug(
|
||||
f"Attempting to delete entity {entity_name} with ID {entity_id}"
|
||||
)
|
||||
|
||||
# Check if the entity exists
|
||||
if self._client.get([entity_id]):
|
||||
await self.delete([entity_id])
|
||||
client = await self._get_client()
|
||||
if client.get([entity_id]):
|
||||
client.delete([entity_id])
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
else:
|
||||
logger.debug(f"Entity {entity_name} not found in storage")
|
||||
@@ -133,16 +181,19 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
try:
|
||||
client = await self._get_client()
|
||||
storage = getattr(client, "_NanoVectorDB__storage")
|
||||
relations = [
|
||||
dp
|
||||
for dp in self.client_storage["data"]
|
||||
for dp in storage["data"]
|
||||
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
|
||||
]
|
||||
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
|
||||
ids_to_delete = [relation["__id__"] for relation in relations]
|
||||
|
||||
if ids_to_delete:
|
||||
await self.delete(ids_to_delete)
|
||||
client = await self._get_client()
|
||||
client.delete(ids_to_delete)
|
||||
logger.debug(
|
||||
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
|
||||
)
|
||||
@@ -151,6 +202,37 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
async with self._save_lock:
|
||||
self._client.save()
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Storage for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim,
|
||||
storage_file=self._client_file_name,
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
self._client.save()
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
return True # Return success
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving data for {self.namespace}: {e}")
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
@@ -1,18 +1,12 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, final
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
)
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseGraphStorage
|
||||
|
||||
from lightrag.base import (
|
||||
BaseGraphStorage,
|
||||
)
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("networkx"):
|
||||
@@ -23,6 +17,12 @@ if not pm.is_installed("graspologic"):
|
||||
|
||||
import networkx as nx
|
||||
from graspologic import embed
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@@ -78,56 +78,101 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
self._graphml_xml_file = os.path.join(
|
||||
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
||||
)
|
||||
self._storage_lock = None
|
||||
self.storage_updated = None
|
||||
self._graph = None
|
||||
|
||||
# Load initial graph
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
if preloaded_graph is not None:
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
else:
|
||||
logger.info("Created new empty graph")
|
||||
self._graph = preloaded_graph or nx.Graph()
|
||||
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(self.namespace)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
async def _get_graph(self):
|
||||
"""Check if the storage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
async with self._storage_lock:
|
||||
# Check if data needs to be reloaded
|
||||
if (is_multiprocess and self.storage_updated.value) or (
|
||||
not is_multiprocess and self.storage_updated
|
||||
):
|
||||
logger.info(
|
||||
f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
|
||||
)
|
||||
# Reload data
|
||||
self._graph = (
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
|
||||
return self._graph
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
return self._graph.has_node(node_id)
|
||||
graph = await self._get_graph()
|
||||
return graph.has_node(node_id)
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
return self._graph.has_edge(source_node_id, target_node_id)
|
||||
graph = await self._get_graph()
|
||||
return graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
return self._graph.nodes.get(node_id)
|
||||
graph = await self._get_graph()
|
||||
return graph.nodes.get(node_id)
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
return self._graph.degree(node_id)
|
||||
graph = await self._get_graph()
|
||||
return graph.degree(node_id)
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
|
||||
graph = await self._get_graph()
|
||||
return graph.degree(src_id) + graph.degree(tgt_id)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> dict[str, str] | None:
|
||||
return self._graph.edges.get((source_node_id, target_node_id))
|
||||
graph = await self._get_graph()
|
||||
return graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
graph = await self._get_graph()
|
||||
if graph.has_node(source_node_id):
|
||||
return list(graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
self._graph.add_node(node_id, **node_data)
|
||||
graph = await self._get_graph()
|
||||
graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
) -> None:
|
||||
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
graph = await self._get_graph()
|
||||
graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
if self._graph.has_node(node_id):
|
||||
self._graph.remove_node(node_id)
|
||||
logger.info(f"Node {node_id} deleted from the graph.")
|
||||
graph = await self._get_graph()
|
||||
if graph.has_node(node_id):
|
||||
graph.remove_node(node_id)
|
||||
logger.debug(f"Node {node_id} deleted from the graph.")
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
|
||||
@@ -138,35 +183,37 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
# @TODO: NOT USED
|
||||
# TODO: NOT USED
|
||||
async def _node2vec_embed(self):
|
||||
graph = await self._get_graph()
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
graph,
|
||||
**self.global_config["node2vec_params"],
|
||||
)
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
def remove_nodes(self, nodes: list[str]):
|
||||
async def remove_nodes(self, nodes: list[str]):
|
||||
"""Delete multiple nodes
|
||||
|
||||
Args:
|
||||
nodes: List of node IDs to be deleted
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
for node in nodes:
|
||||
if self._graph.has_node(node):
|
||||
self._graph.remove_node(node)
|
||||
if graph.has_node(node):
|
||||
graph.remove_node(node)
|
||||
|
||||
def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
async def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
"""Delete multiple edges
|
||||
|
||||
Args:
|
||||
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
for source, target in edges:
|
||||
if self._graph.has_edge(source, target):
|
||||
self._graph.remove_edge(source, target)
|
||||
if graph.has_edge(source, target):
|
||||
graph.remove_edge(source, target)
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
@@ -174,8 +221,9 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
[label1, label2, ...] # Alphabetically sorted label list
|
||||
"""
|
||||
graph = await self._get_graph()
|
||||
labels = set()
|
||||
for node in self._graph.nodes():
|
||||
for node in graph.nodes():
|
||||
labels.add(str(node)) # Add node id as a label
|
||||
|
||||
# Return sorted list
|
||||
@@ -198,16 +246,18 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
graph = await self._get_graph()
|
||||
|
||||
# Handle special case for "*" label
|
||||
if node_label == "*":
|
||||
# For "*", return the entire graph including all nodes and edges
|
||||
subgraph = (
|
||||
self._graph.copy()
|
||||
graph.copy()
|
||||
) # Create a copy to avoid modifying the original graph
|
||||
else:
|
||||
# Find nodes with matching node id (partial match)
|
||||
nodes_to_explore = []
|
||||
for n, attr in self._graph.nodes(data=True):
|
||||
for n, attr in graph.nodes(data=True):
|
||||
if node_label in str(n): # Use partial matching
|
||||
nodes_to_explore.append(n)
|
||||
|
||||
@@ -216,7 +266,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
return result
|
||||
|
||||
# Get subgraph using ego_graph
|
||||
subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
|
||||
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
|
||||
|
||||
# Check if number of nodes exceeds max_graph_nodes
|
||||
max_graph_nodes = 500
|
||||
@@ -278,9 +328,41 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
# logger.info(result.edges)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
return result
|
||||
|
||||
async def index_done_callback(self) -> bool:
|
||||
"""Save data to disk"""
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(
|
||||
f"Graph for {self.namespace} was updated by another process, reloading..."
|
||||
)
|
||||
self._graph = (
|
||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||
)
|
||||
# Reset update flag
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
# Notify other processes that data has been updated
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
return True # Return success
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving graph for {self.namespace}: {e}")
|
||||
return False # Return error
|
||||
|
||||
return True
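Both the NanoVectorDB and NetworkX storages now share the same save-or-reload contract in `index_done_callback`. A condensed sketch of that contract, with `save_to_disk`/`reload_from_disk` standing in for the storage-specific persistence calls (they are placeholders, not names from this diff):

```python
from lightrag.utils import logger
from lightrag.kg.shared_storage import set_all_update_flags, is_multiprocess


async def index_done_callback(self) -> bool:
    # Condensed sketch; self.reload_from_disk / self.save_to_disk are placeholders.
    if is_multiprocess and self.storage_updated.value:
        # Another worker saved newer data: reload instead of overwriting it
        self.reload_from_disk()
        self.storage_updated.value = False
        return False  # nothing was persisted by this worker

    async with self._storage_lock:
        try:
            self.save_to_disk()
            await set_all_update_flags(self.namespace)  # tell other workers to reload
            # Reset our own flag so we do not reload the data we just wrote
            if is_multiprocess:
                self.storage_updated.value = False
            else:
                self.storage_updated = False
            return True
        except Exception as e:
            logger.error(f"Error saving data for {self.namespace}: {e}")
            return False
```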
|
||||
|
@@ -38,8 +38,8 @@ import pipmaster as pm
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import asyncpg
|
||||
from asyncpg import Pool
|
||||
import asyncpg # type: ignore
|
||||
from asyncpg import Pool # type: ignore
|
||||
|
||||
|
||||
class PostgreSQLDB:
|
||||
|
374
lightrag/kg/shared_storage.py
Normal file
@@ -0,0 +1,374 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
from multiprocessing.synchronize import Lock as ProcessLock
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, Dict, Optional, Union, TypeVar, Generic
|
||||
|
||||
|
||||
# Define a direct print function for critical logs that must be visible in all processes
|
||||
def direct_log(message, level="INFO"):
|
||||
"""
|
||||
Log a message directly to stderr to ensure visibility in all processes,
|
||||
including the Gunicorn master process.
|
||||
"""
|
||||
print(f"{level}: {message}", file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
LockType = Union[ProcessLock, asyncio.Lock]
|
||||
|
||||
is_multiprocess = None
|
||||
_workers = None
|
||||
_manager = None
|
||||
_initialized = None
|
||||
|
||||
# shared data for storage across processes
|
||||
_shared_dicts: Optional[Dict[str, Any]] = None
|
||||
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
|
||||
_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
|
||||
|
||||
# locks for mutex access
|
||||
_storage_lock: Optional[LockType] = None
|
||||
_internal_lock: Optional[LockType] = None
|
||||
_pipeline_status_lock: Optional[LockType] = None
|
||||
|
||||
|
||||
class UnifiedLock(Generic[T]):
|
||||
"""Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
|
||||
|
||||
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
|
||||
self._lock = lock
|
||||
self._is_async = is_async
|
||||
|
||||
async def __aenter__(self) -> "UnifiedLock[T]":
|
||||
if self._is_async:
|
||||
await self._lock.acquire()
|
||||
else:
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._is_async:
|
||||
self._lock.release()
|
||||
else:
|
||||
self._lock.release()
|
||||
|
||||
def __enter__(self) -> "UnifiedLock[T]":
|
||||
"""For backward compatibility"""
|
||||
if self._is_async:
|
||||
raise RuntimeError("Use 'async with' for shared_storage lock")
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""For backward compatibility"""
|
||||
if self._is_async:
|
||||
raise RuntimeError("Use 'async with' for shared_storage lock")
|
||||
self._lock.release()
|
||||
|
||||
|
||||
def get_internal_lock() -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
|
||||
|
||||
|
||||
def get_storage_lock() -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
|
||||
|
||||
|
||||
def get_pipeline_status_lock() -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
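`UnifiedLock` gives the storage code one `async with` interface whether the underlying lock is a `multiprocessing` Manager lock (workers > 1) or an `asyncio.Lock` (single worker); the synchronous `with` form is kept only for backward compatibility and raises in async mode. A small usage sketch (the `"demo"` namespace is illustrative):

```python
import asyncio

from lightrag.kg.shared_storage import (
    initialize_share_data,
    get_storage_lock,
    get_namespace_data,
)


async def bump_counter() -> None:
    data = await get_namespace_data("demo")  # "demo" is an illustrative namespace
    async with get_storage_lock():  # same code path for asyncio and Manager locks
        data["counter"] = data.get("counter", 0) + 1


initialize_share_data(workers=1)  # single-process mode: plain dicts + asyncio locks
asyncio.run(bump_counter())
```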
|
||||
|
||||
|
||||
def initialize_share_data(workers: int = 1):
|
||||
"""
|
||||
Initialize shared storage data for single or multi-process mode.
|
||||
|
||||
When used with Gunicorn's preload feature, this function is called once in the
|
||||
master process before forking worker processes, allowing all workers to share
|
||||
the same initialized data.
|
||||
|
||||
In single-process mode, this function is called in the FastAPI lifespan function.
|
||||
|
||||
The function determines whether to use cross-process shared variables for data storage
|
||||
based on the number of workers. If workers=1, it uses asyncio locks and local dictionaries.
|
||||
If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager.
|
||||
|
||||
Args:
|
||||
workers (int): Number of worker processes. If 1, single-process mode is used.
|
||||
If > 1, multi-process mode with shared memory is used.
|
||||
"""
|
||||
global \
|
||||
_manager, \
|
||||
_workers, \
|
||||
is_multiprocess, \
|
||||
_storage_lock, \
|
||||
_internal_lock, \
|
||||
_pipeline_status_lock, \
|
||||
_shared_dicts, \
|
||||
_init_flags, \
|
||||
_initialized, \
|
||||
_update_flags
|
||||
|
||||
# Check if already initialized
|
||||
if _initialized:
|
||||
direct_log(
|
||||
f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
|
||||
)
|
||||
return
|
||||
|
||||
_manager = Manager()
|
||||
_workers = workers
|
||||
|
||||
if workers > 1:
|
||||
is_multiprocess = True
|
||||
_internal_lock = _manager.Lock()
|
||||
_storage_lock = _manager.Lock()
|
||||
_pipeline_status_lock = _manager.Lock()
|
||||
_shared_dicts = _manager.dict()
|
||||
_init_flags = _manager.dict()
|
||||
_update_flags = _manager.dict()
|
||||
direct_log(
|
||||
f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
|
||||
)
|
||||
else:
|
||||
is_multiprocess = False
|
||||
_internal_lock = asyncio.Lock()
|
||||
_storage_lock = asyncio.Lock()
|
||||
_pipeline_status_lock = asyncio.Lock()
|
||||
_shared_dicts = {}
|
||||
_init_flags = {}
|
||||
_update_flags = {}
|
||||
direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
|
||||
|
||||
# Mark as initialized
|
||||
_initialized = True
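In multi-worker mode this function is meant to run once in the Gunicorn master process (via preload) so every forked worker inherits the same Manager-backed objects. The exact hook wiring lives in `gunicorn_config.py` and the API lifespan, which are not shown in this hunk; the sketch below only illustrates where such calls would plausibly go:

```python
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data


def on_starting(server):
    # Gunicorn master-process hook: runs once before worker processes are forked,
    # so the Manager-backed shared data is inherited by every worker.
    initialize_share_data(workers=server.cfg.workers)


def on_exit(server):
    # Release the Manager and shared objects on shutdown.
    finalize_share_data()


# Single-process path (lightrag-server / uvicorn): call initialize_share_data(workers=1)
# once during application startup instead.
```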
|
||||
|
||||
|
||||
async def initialize_pipeline_status():
|
||||
"""
|
||||
Initialize pipeline namespace with default values.
|
||||
This function is called during the FastAPI lifespan for each worker.
|
||||
"""
|
||||
pipeline_namespace = await get_namespace_data("pipeline_status")
|
||||
|
||||
async with get_internal_lock():
|
||||
# Check if already initialized by checking for required fields
|
||||
if "busy" in pipeline_namespace:
|
||||
return
|
||||
|
||||
# Create a shared list object for history_messages
|
||||
history_messages = _manager.list() if is_multiprocess else []
|
||||
pipeline_namespace.update(
|
||||
{
|
||||
"busy": False, # Control concurrent processes
|
||||
"job_name": "Default Job", # Current job name (indexing files/indexing texts)
|
||||
"job_start": None, # Job start time
|
||||
"docs": 0, # Total number of documents to be indexed
|
||||
"batchs": 0, # Number of batches for processing documents
|
||||
"cur_batch": 0, # Current processing batch
|
||||
"request_pending": False, # Flag for pending request for processing
|
||||
"latest_message": "", # Latest message from pipeline processing
|
||||
"history_messages": history_messages, # 使用共享列表对象
|
||||
}
|
||||
)
|
||||
direct_log(f"Process {os.getpid()} Pipeline namespace initialized")
|
||||
|
||||
|
||||
async def get_update_flag(namespace: str):
|
||||
"""
|
||||
Create a namespace's update flag for a worker.
Return the update flag to the caller for referencing or resetting.
|
||||
"""
|
||||
global _update_flags
|
||||
if _update_flags is None:
|
||||
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
||||
|
||||
async with get_internal_lock():
|
||||
if namespace not in _update_flags:
|
||||
if is_multiprocess and _manager is not None:
|
||||
_update_flags[namespace] = _manager.list()
|
||||
else:
|
||||
_update_flags[namespace] = []
|
||||
direct_log(
|
||||
f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
|
||||
)
|
||||
|
||||
if is_multiprocess and _manager is not None:
|
||||
new_update_flag = _manager.Value("b", False)
|
||||
else:
|
||||
new_update_flag = False
|
||||
|
||||
_update_flags[namespace].append(new_update_flag)
|
||||
return new_update_flag
|
||||
|
||||
|
||||
async def set_all_update_flags(namespace: str):
|
||||
"""Set all update flag of namespace indicating all workers need to reload data from files"""
|
||||
global _update_flags
|
||||
if _update_flags is None:
|
||||
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
||||
|
||||
async with get_internal_lock():
|
||||
if namespace not in _update_flags:
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for both modes
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
if is_multiprocess:
|
||||
_update_flags[namespace][i].value = True
|
||||
else:
|
||||
_update_flags[namespace][i] = True
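`get_update_flag` registers one flag per worker for a namespace, and `set_all_update_flags` flips every registered flag after a save; each worker then clears only its own flag once it has reloaded. A minimal, runnable sketch of that fan-out (single-process mode, illustrative namespace name):

```python
import asyncio

from lightrag.kg.shared_storage import (
    initialize_share_data,
    get_update_flag,
    set_all_update_flags,
)


async def demo() -> None:
    ns = "demo"  # illustrative namespace
    # Each worker (or storage instance) registers and keeps its own flag object
    my_flag = await get_update_flag(ns)
    # After persisting data, the writer marks every registered flag ...
    await set_all_update_flags(ns)
    # ... and in multi-process mode would clear only its own flag (my_flag.value = False)
    # so it does not reload the data it just wrote.


initialize_share_data(workers=1)
asyncio.run(demo())
```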
|
||||
|
||||
|
||||
async def get_all_update_flags_status() -> Dict[str, list]:
|
||||
"""
|
||||
Get update flags status for all namespaces.
|
||||
|
||||
Returns:
|
||||
Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses
|
||||
"""
|
||||
if _update_flags is None:
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
async with get_internal_lock():
|
||||
for namespace, flags in _update_flags.items():
|
||||
worker_statuses = []
|
||||
for flag in flags:
|
||||
if is_multiprocess:
|
||||
worker_statuses.append(flag.value)
|
||||
else:
|
||||
worker_statuses.append(flag)
|
||||
result[namespace] = worker_statuses
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def try_initialize_namespace(namespace: str) -> bool:
|
||||
"""
|
||||
Returns True if the current worker (process) gets permission to initialize the namespace and load data later.
A worker that does not get permission must not load data from files.
|
||||
"""
|
||||
global _init_flags, _manager
|
||||
|
||||
if _init_flags is None:
|
||||
raise ValueError("Try to create nanmespace before Shared-Data is initialized")
|
||||
|
||||
if namespace not in _init_flags:
|
||||
_init_flags[namespace] = True
|
||||
direct_log(
|
||||
f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
|
||||
)
|
||||
return True
|
||||
direct_log(
|
||||
f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
||||
"""get the shared data reference for specific namespace"""
|
||||
if _shared_dicts is None:
|
||||
direct_log(
|
||||
f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
|
||||
level="ERROR",
|
||||
)
|
||||
raise ValueError("Shared dictionaries not initialized")
|
||||
|
||||
async with get_internal_lock():
|
||||
if namespace not in _shared_dicts:
|
||||
if is_multiprocess and _manager is not None:
|
||||
_shared_dicts[namespace] = _manager.dict()
|
||||
else:
|
||||
_shared_dicts[namespace] = {}
|
||||
|
||||
return _shared_dicts[namespace]
|
||||
|
||||
|
||||
def finalize_share_data():
|
||||
"""
|
||||
Release shared resources and clean up.
|
||||
|
||||
This function should be called when the application is shutting down
|
||||
to properly release shared resources and avoid memory leaks.
|
||||
|
||||
In multi-process mode, it shuts down the Manager and releases all shared objects.
|
||||
In single-process mode, it simply resets the global variables.
|
||||
"""
|
||||
global \
|
||||
_manager, \
|
||||
is_multiprocess, \
|
||||
_storage_lock, \
|
||||
_internal_lock, \
|
||||
_pipeline_status_lock, \
|
||||
_shared_dicts, \
|
||||
_init_flags, \
|
||||
_initialized, \
|
||||
_update_flags
|
||||
|
||||
# Check if already initialized
|
||||
if not _initialized:
|
||||
direct_log(
|
||||
f"Process {os.getpid()} storage data not initialized, nothing to finalize"
|
||||
)
|
||||
return
|
||||
|
||||
direct_log(
|
||||
f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
|
||||
)
|
||||
|
||||
# In multi-process mode, shut down the Manager
|
||||
if is_multiprocess and _manager is not None:
|
||||
try:
|
||||
# Clear shared resources before shutting down Manager
|
||||
if _shared_dicts is not None:
|
||||
# Clear pipeline status history messages first if exists
|
||||
try:
|
||||
pipeline_status = _shared_dicts.get("pipeline_status", {})
|
||||
if "history_messages" in pipeline_status:
|
||||
pipeline_status["history_messages"].clear()
|
||||
except Exception:
|
||||
pass # Ignore any errors during history messages cleanup
|
||||
_shared_dicts.clear()
|
||||
if _init_flags is not None:
|
||||
_init_flags.clear()
|
||||
if _update_flags is not None:
|
||||
# Clear each namespace's update flags list and Value objects
|
||||
try:
|
||||
for namespace in _update_flags:
|
||||
flags_list = _update_flags[namespace]
|
||||
if isinstance(flags_list, list):
|
||||
# Clear Value objects in the list
|
||||
for flag in flags_list:
|
||||
if hasattr(
|
||||
flag, "value"
|
||||
): # Check if it's a Value object
|
||||
flag.value = False
|
||||
flags_list.clear()
|
||||
except Exception:
|
||||
pass # Ignore any errors during update flags cleanup
|
||||
_update_flags.clear()
|
||||
|
||||
# Shut down the Manager - this will automatically clean up all shared resources
|
||||
_manager.shutdown()
|
||||
direct_log(f"Process {os.getpid()} Manager shutdown complete")
|
||||
except Exception as e:
|
||||
direct_log(
|
||||
f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR"
|
||||
)
|
||||
|
||||
# Reset global variables
|
||||
_manager = None
|
||||
_initialized = None
|
||||
is_multiprocess = None
|
||||
_shared_dicts = None
|
||||
_init_flags = None
|
||||
_storage_lock = None
|
||||
_internal_lock = None
|
||||
_pipeline_status_lock = None
|
||||
_update_flags = None
|
||||
|
||||
direct_log(f"Process {os.getpid()} storage data finalization complete")
|
@@ -45,7 +45,6 @@ from .utils import (
|
||||
lazy_external_import,
|
||||
limit_async_func_call,
|
||||
logger,
|
||||
set_logger,
|
||||
)
|
||||
from .types import KnowledgeGraph
|
||||
from dotenv import load_dotenv
|
||||
@@ -268,9 +267,14 @@ class LightRAG:
|
||||
|
||||
def __post_init__(self):
|
||||
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
|
||||
set_logger(self.log_file_path, self.log_level)
|
||||
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
||||
|
||||
from lightrag.kg.shared_storage import (
|
||||
initialize_share_data,
|
||||
)
|
||||
|
||||
initialize_share_data()
|
||||
|
||||
if not os.path.exists(self.working_dir):
|
||||
logger.info(f"Creating working directory {self.working_dir}")
|
||||
os.makedirs(self.working_dir)
|
||||
@@ -692,117 +696,221 @@ class LightRAG:
|
||||
3. Process each chunk for entity and relation extraction
|
||||
4. Update the document status
|
||||
"""
|
||||
# 1. Get all pending, failed, and abnormally terminated processing documents.
|
||||
# Run the asynchronous status retrievals in parallel using asyncio.gather
|
||||
processing_docs, failed_docs, pending_docs = await asyncio.gather(
|
||||
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
|
||||
self.doc_status.get_docs_by_status(DocStatus.FAILED),
|
||||
self.doc_status.get_docs_by_status(DocStatus.PENDING),
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_pipeline_status_lock,
|
||||
)
|
||||
|
||||
to_process_docs: dict[str, DocProcessingStatus] = {}
|
||||
to_process_docs.update(processing_docs)
|
||||
to_process_docs.update(failed_docs)
|
||||
to_process_docs.update(pending_docs)
|
||||
# Get pipeline status shared data and lock
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
pipeline_status_lock = get_pipeline_status_lock()
|
||||
|
||||
if not to_process_docs:
|
||||
logger.info("All documents have been processed or are duplicates")
|
||||
return
|
||||
# Check if another process is already processing the queue
|
||||
async with pipeline_status_lock:
|
||||
# Ensure only one worker is processing documents
|
||||
if not pipeline_status.get("busy", False):
|
||||
# First check whether there are any documents that need processing
|
||||
processing_docs, failed_docs, pending_docs = await asyncio.gather(
|
||||
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
|
||||
self.doc_status.get_docs_by_status(DocStatus.FAILED),
|
||||
self.doc_status.get_docs_by_status(DocStatus.PENDING),
|
||||
)
|
||||
|
||||
# 2. split docs into chunks, insert chunks, update doc status
|
||||
docs_batches = [
|
||||
list(to_process_docs.items())[i : i + self.max_parallel_insert]
|
||||
for i in range(0, len(to_process_docs), self.max_parallel_insert)
|
||||
]
|
||||
to_process_docs: dict[str, DocProcessingStatus] = {}
|
||||
to_process_docs.update(processing_docs)
|
||||
to_process_docs.update(failed_docs)
|
||||
to_process_docs.update(pending_docs)
|
||||
|
||||
logger.info(f"Number of batches to process: {len(docs_batches)}.")
|
||||
# If there are no documents to process, return immediately and leave pipeline_status unchanged
|
||||
if not to_process_docs:
|
||||
logger.info("No documents to process")
|
||||
return
|
||||
|
||||
batches: list[Any] = []
|
||||
# 3. iterate over batches
|
||||
for batch_idx, docs_batch in enumerate(docs_batches):
|
||||
|
||||
async def batch(
|
||||
batch_idx: int,
|
||||
docs_batch: list[tuple[str, DocProcessingStatus]],
|
||||
size_batch: int,
|
||||
) -> None:
|
||||
logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.")
|
||||
# 4. iterate over batch
|
||||
for doc_id_processing_status in docs_batch:
|
||||
doc_id, status_doc = doc_id_processing_status
|
||||
# Generate chunks from document
|
||||
chunks: dict[str, Any] = {
|
||||
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
||||
**dp,
|
||||
"full_doc_id": doc_id,
|
||||
}
|
||||
for dp in self.chunking_func(
|
||||
status_doc.content,
|
||||
split_by_character,
|
||||
split_by_character_only,
|
||||
self.chunk_overlap_token_size,
|
||||
self.chunk_token_size,
|
||||
self.tiktoken_model_name,
|
||||
)
|
||||
# There are documents to process, update pipeline_status
|
||||
pipeline_status.update(
|
||||
{
|
||||
"busy": True,
|
||||
"job_name": "indexing files",
|
||||
"job_start": datetime.now().isoformat(),
|
||||
"docs": 0,
|
||||
"batchs": 0,
|
||||
"cur_batch": 0,
|
||||
"request_pending": False, # Clear any previous request
|
||||
"latest_message": "",
|
||||
}
|
||||
# Process document (text chunks and full docs) in parallel
|
||||
tasks = [
|
||||
self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.PROCESSING,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
}
|
||||
}
|
||||
),
|
||||
self.chunks_vdb.upsert(chunks),
|
||||
self._process_entity_relation_graph(chunks),
|
||||
self.full_docs.upsert(
|
||||
{doc_id: {"content": status_doc.content}}
|
||||
),
|
||||
self.text_chunks.upsert(chunks),
|
||||
]
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.PROCESSED,
|
||||
"chunks_count": len(chunks),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
}
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process document {doc_id}: {str(e)}")
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.FAILED,
|
||||
"error": str(e),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
}
|
||||
}
|
||||
)
|
||||
continue
|
||||
logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.")
|
||||
)
|
||||
# Clear history_messages in place so it remains a shared list object
|
||||
del pipeline_status["history_messages"][:]
|
||||
else:
|
||||
# Another process is busy, just set request flag and return
|
||||
pipeline_status["request_pending"] = True
|
||||
logger.info(
|
||||
"Another process is already processing the document queue. Request queued."
|
||||
)
|
||||
return
|
||||
|
||||
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
|
||||
try:
|
||||
# Process documents until no more documents or requests
|
||||
while True:
|
||||
if not to_process_docs:
|
||||
log_message = "All documents have been processed or are duplicates"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
break
|
||||
|
||||
await asyncio.gather(*batches)
|
||||
await self._insert_done()
|
||||
# 2. split docs into chunks, insert chunks, update doc status
|
||||
docs_batches = [
|
||||
list(to_process_docs.items())[i : i + self.max_parallel_insert]
|
||||
for i in range(0, len(to_process_docs), self.max_parallel_insert)
|
||||
]
|
||||
|
||||
log_message = f"Number of batches to process: {len(docs_batches)}."
|
||||
logger.info(log_message)
|
||||
|
||||
# Update pipeline status with current batch information
|
||||
pipeline_status["docs"] += len(to_process_docs)
|
||||
pipeline_status["batchs"] += len(docs_batches)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
batches: list[Any] = []
|
||||
# 3. iterate over batches
|
||||
for batch_idx, docs_batch in enumerate(docs_batches):
|
||||
# Update current batch in pipeline status (directly, as it's atomic)
|
||||
pipeline_status["cur_batch"] += 1
|
||||
|
||||
async def batch(
|
||||
batch_idx: int,
|
||||
docs_batch: list[tuple[str, DocProcessingStatus]],
|
||||
size_batch: int,
|
||||
) -> None:
|
||||
log_message = (
|
||||
f"Start processing batch {batch_idx + 1} of {size_batch}."
|
||||
)
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
# 4. iterate over batch
|
||||
for doc_id_processing_status in docs_batch:
|
||||
doc_id, status_doc = doc_id_processing_status
|
||||
# Generate chunks from document
|
||||
chunks: dict[str, Any] = {
|
||||
compute_mdhash_id(dp["content"], prefix="chunk-"): {
|
||||
**dp,
|
||||
"full_doc_id": doc_id,
|
||||
}
|
||||
for dp in self.chunking_func(
|
||||
status_doc.content,
|
||||
split_by_character,
|
||||
split_by_character_only,
|
||||
self.chunk_overlap_token_size,
|
||||
self.chunk_token_size,
|
||||
self.tiktoken_model_name,
|
||||
)
|
||||
}
|
||||
# Process document (text chunks and full docs) in parallel
|
||||
tasks = [
|
||||
self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.PROCESSING,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
}
|
||||
}
|
||||
),
|
||||
self.chunks_vdb.upsert(chunks),
|
||||
self._process_entity_relation_graph(chunks),
|
||||
self.full_docs.upsert(
|
||||
{doc_id: {"content": status_doc.content}}
|
||||
),
|
||||
self.text_chunks.upsert(chunks),
|
||||
]
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.PROCESSED,
|
||||
"chunks_count": len(chunks),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
}
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process document {doc_id}: {str(e)}"
|
||||
)
|
||||
await self.doc_status.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"status": DocStatus.FAILED,
|
||||
"error": str(e),
|
||||
"content": status_doc.content,
|
||||
"content_summary": status_doc.content_summary,
|
||||
"content_length": status_doc.content_length,
|
||||
"created_at": status_doc.created_at,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
}
|
||||
}
|
||||
)
|
||||
continue
|
||||
log_message = (
|
||||
f"Completed batch {batch_idx + 1} of {len(docs_batches)}."
|
||||
)
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
batches.append(batch(batch_idx, docs_batch, len(docs_batches)))
|
||||
|
||||
await asyncio.gather(*batches)
|
||||
await self._insert_done()
|
||||
|
||||
# Check if there's a pending request to process more documents (with lock)
|
||||
has_pending_request = False
|
||||
async with pipeline_status_lock:
|
||||
has_pending_request = pipeline_status.get("request_pending", False)
|
||||
if has_pending_request:
|
||||
# Clear the request flag before checking for more documents
|
||||
pipeline_status["request_pending"] = False
|
||||
|
||||
if not has_pending_request:
|
||||
break
|
||||
|
||||
log_message = "Processing additional documents due to pending request"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
# Fetch newly pending documents
|
||||
processing_docs, failed_docs, pending_docs = await asyncio.gather(
|
||||
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
|
||||
self.doc_status.get_docs_by_status(DocStatus.FAILED),
|
||||
self.doc_status.get_docs_by_status(DocStatus.PENDING),
|
||||
)
|
||||
|
||||
to_process_docs = {}
|
||||
to_process_docs.update(processing_docs)
|
||||
to_process_docs.update(failed_docs)
|
||||
to_process_docs.update(pending_docs)
|
||||
|
||||
finally:
|
||||
log_message = "Document processing pipeline completed"
|
||||
logger.info(log_message)
|
||||
# Always reset busy status when done or if an exception occurs (with lock)
|
||||
async with pipeline_status_lock:
|
||||
pipeline_status["busy"] = False
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
|
||||
try:
|
||||
@@ -833,7 +941,16 @@ class LightRAG:
|
||||
if storage_inst is not None
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info("All Insert done")
|
||||
|
||||
log_message = "All Insert done"
|
||||
logger.info(log_message)
|
||||
|
||||
# Get pipeline_status and update latest_message and history_messages
|
||||
from lightrag.kg.shared_storage import get_namespace_data
|
||||
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None:
|
||||
loop = always_get_an_event_loop()
|
||||
|
@@ -339,6 +339,9 @@ async def extract_entities(
|
||||
global_config: dict[str, str],
|
||||
llm_response_cache: BaseKVStorage | None = None,
|
||||
) -> None:
|
||||
from lightrag.kg.shared_storage import get_namespace_data
|
||||
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||
enable_llm_cache_for_entity_extract: bool = global_config[
|
||||
@@ -499,9 +502,10 @@ async def extract_entities(
|
||||
processed_chunks += 1
|
||||
entities_count = len(maybe_nodes)
|
||||
relations_count = len(maybe_edges)
|
||||
logger.info(
|
||||
f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
|
||||
)
|
||||
log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
return dict(maybe_nodes), dict(maybe_edges)
|
||||
|
||||
tasks = [_process_single_content(c) for c in ordered_chunks]
|
||||
@@ -530,17 +534,27 @@ async def extract_entities(
|
||||
)
|
||||
|
||||
if not (all_entities_data or all_relationships_data):
|
||||
logger.info("Didn't extract any entities and relationships.")
|
||||
log_message = "Didn't extract any entities and relationships."
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
return
|
||||
|
||||
if not all_entities_data:
|
||||
logger.info("Didn't extract any entities")
|
||||
log_message = "Didn't extract any entities"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
if not all_relationships_data:
|
||||
logger.info("Didn't extract any relationships")
|
||||
log_message = "Didn't extract any relationships"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
|
||||
)
|
||||
log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
|
||||
logger.info(log_message)
|
||||
pipeline_status["latest_message"] = log_message
|
||||
pipeline_status["history_messages"].append(log_message)
|
||||
verbose_debug(
|
||||
f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
|
||||
)
|
||||
|
@@ -56,6 +56,18 @@ def set_verbose_debug(enabled: bool):
|
||||
VERBOSE_DEBUG = enabled
|
||||
|
||||
|
||||
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger("lightrag")
|
||||
logger.propagate = False  # prevent log messages from propagating to the root logger
|
||||
# Let the main application configure the handlers
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Set httpx logging level to WARNING
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class UnlimitedSemaphore:
|
||||
"""A context manager that allows unlimited access."""
|
||||
|
||||
@@ -68,34 +80,6 @@ class UnlimitedSemaphore:
|
||||
|
||||
ENCODER = None
|
||||
|
||||
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
|
||||
|
||||
logger = logging.getLogger("lightrag")
|
||||
|
||||
# Set httpx logging level to WARNING
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_logger(log_file: str, level: int = logging.DEBUG):
|
||||
"""Set up file logging with the specified level.
|
||||
|
||||
Args:
|
||||
log_file: Path to the log file
|
||||
level: Logging level (e.g. logging.DEBUG, logging.INFO)
|
||||
"""
|
||||
logger.setLevel(level)
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setLevel(level)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
if not logger.handlers:
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingFunc:
|
||||
|
203
run_with_gunicorn.py
Executable file
@@ -0,0 +1,203 @@
#!/usr/bin/env python
"""
Start LightRAG server with Gunicorn
"""

import os
import sys
import signal
import pipmaster as pm
from lightrag.api.utils_api import parse_args, display_splash_screen
from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data


def check_and_install_dependencies():
    """Check and install required dependencies"""
    required_packages = [
        "gunicorn",
        "tiktoken",
        "psutil",
        # Add other required packages here
    ]

    for package in required_packages:
        if not pm.is_installed(package):
            print(f"Installing {package}...")
            pm.install(package)
            print(f"{package} installed successfully")


# Signal handler for graceful shutdown
def signal_handler(sig, frame):
    print("\n\n" + "=" * 80)
    print("RECEIVED TERMINATION SIGNAL")
    print(f"Process ID: {os.getpid()}")
    print("=" * 80 + "\n")

    # Release shared resources
    finalize_share_data()

    # Exit with success status
    sys.exit(0)


def main():
    # Check and install dependencies
    check_and_install_dependencies()

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # kill command

    # Parse all arguments using parse_args
    args = parse_args(is_uvicorn_mode=False)

    # Display startup information
    display_splash_screen(args)

    print("🚀 Starting LightRAG with Gunicorn")
    print(f"🔄 Worker management: Gunicorn (workers={args.workers})")
    print("🔍 Preloading app: Enabled")
    print("📝 Note: Using Gunicorn's preload feature for shared data initialization")
    print("\n\n" + "=" * 80)
    print("MAIN PROCESS INITIALIZATION")
    print(f"Process ID: {os.getpid()}")
    print(f"Workers setting: {args.workers}")
    print("=" * 80 + "\n")

    # Import Gunicorn's StandaloneApplication
    from gunicorn.app.base import BaseApplication

    # Define a custom application class that loads our config
    class GunicornApp(BaseApplication):
        def __init__(self, app, options=None):
            self.options = options or {}
            self.application = app
            super().__init__()

        def load_config(self):
            # Define valid Gunicorn configuration options
            valid_options = {
                "bind",
                "workers",
                "worker_class",
                "timeout",
                "keepalive",
                "preload_app",
                "errorlog",
                "accesslog",
                "loglevel",
                "certfile",
                "keyfile",
                "limit_request_line",
                "limit_request_fields",
                "limit_request_field_size",
                "graceful_timeout",
                "max_requests",
                "max_requests_jitter",
            }

            # Special hooks that need to be set separately
            special_hooks = {
                "on_starting",
                "on_reload",
                "on_exit",
                "pre_fork",
                "post_fork",
                "pre_exec",
                "pre_request",
                "post_request",
                "worker_init",
                "worker_exit",
                "nworkers_changed",
                "child_exit",
            }

            # Import and configure the gunicorn_config module
            import gunicorn_config

            # Set configuration variables in gunicorn_config, prioritizing command line arguments
            gunicorn_config.workers = (
                args.workers if args.workers else int(os.getenv("WORKERS", 1))
            )

            # Bind configuration prioritizes command line arguments
            host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0")
            port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621))
            gunicorn_config.bind = f"{host}:{port}"

            # Log level configuration prioritizes command line arguments
            gunicorn_config.loglevel = (
                args.log_level.lower()
                if args.log_level
                else os.getenv("LOG_LEVEL", "info")
            )

            # Timeout configuration prioritizes command line arguments
            gunicorn_config.timeout = (
                args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150))
            )

            # Keepalive configuration
            gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5))

            # SSL configuration prioritizes command line arguments
            if args.ssl or os.getenv("SSL", "").lower() in (
                "true",
                "1",
                "yes",
                "t",
                "on",
            ):
                gunicorn_config.certfile = (
                    args.ssl_certfile
                    if args.ssl_certfile
                    else os.getenv("SSL_CERTFILE")
                )
                gunicorn_config.keyfile = (
                    args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE")
                )

            # Set configuration options from the module
            for key in dir(gunicorn_config):
                if key in valid_options:
                    value = getattr(gunicorn_config, key)
                    # Skip functions like on_starting and None values
                    if not callable(value) and value is not None:
                        self.cfg.set(key, value)
                # Set special hooks
                elif key in special_hooks:
                    value = getattr(gunicorn_config, key)
                    if callable(value):
                        self.cfg.set(key, value)

            if hasattr(gunicorn_config, "logconfig_dict"):
                self.cfg.set(
                    "logconfig_dict", getattr(gunicorn_config, "logconfig_dict")
                )

        def load(self):
            # Import the application
            from lightrag.api.lightrag_server import get_application

            return get_application(args)

    # Create the application
    app = GunicornApp("")

    # Convert workers to an integer and use multi-process initialization when > 1
    workers_count = int(args.workers)
    if workers_count > 1:
        # Set a flag to indicate we're in the main process
        os.environ["LIGHTRAG_MAIN_PROCESS"] = "1"
        initialize_share_data(workers_count)
    else:
        initialize_share_data(1)

    # Run the application
    print("\nStarting Gunicorn with direct Python API...")
    app.run()


if __name__ == "__main__":
    main()
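The loader above reads its defaults from a `gunicorn_config` module: plain module-level variables are applied when their names appear in `valid_options`, callables matching `special_hooks` are registered as Gunicorn server hooks, and an optional `logconfig_dict` is passed through. The sketch below shows the kind of module that interface expects; the values and the hook are placeholders, and the actual `gunicorn_config.py` shipped with LightRAG may set different options:

```python
# gunicorn_config.py -- minimal sketch of the module run_with_gunicorn.py imports.
# The values below are placeholders; the real module in the repository may differ.
import os

workers = int(os.getenv("WORKERS", 1))      # overridden by --workers at startup
bind = "0.0.0.0:9621"                       # overridden from --host/--port
loglevel = os.getenv("LOG_LEVEL", "info")   # overridden by --log-level
timeout = int(os.getenv("TIMEOUT", 150))
keepalive = int(os.getenv("KEEPALIVE", 5))
preload_app = True                          # share data is initialized once in the master


def on_starting(server):
    # Example server hook; picked up via the special_hooks set in load_config().
    server.log.info("LightRAG Gunicorn master starting")
```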
1
setup.py
@@ -112,6 +112,7 @@ setuptools.setup(
    entry_points={
        "console_scripts": [
            "lightrag-server=lightrag.api.lightrag_server:main [api]",
            "lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]",
            "lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]",
        ],
    },
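The new `lightrag-gunicorn` console script declared here resolves to the `main()` function of `lightrag.api.run_with_gunicorn` (available when the `api` extra is installed). A rough sketch of the wrapper that setuptools generates for it:

```python
# Rough equivalent of the generated `lightrag-gunicorn` console script.
from lightrag.api.run_with_gunicorn import main

if __name__ == "__main__":
    main()
```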