Add multiple workers support for API Server

yangdx
2025-02-25 09:37:00 +08:00
parent 2e13def95c
commit d74a23d2cc
3 changed files with 147 additions and 107 deletions
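Usage note: with this change the worker count comes from the new --workers flag or the WORKERS environment variable (default 2) added in parse_args below, e.g. lightrag-server --workers 4, or WORKERS=4 set in the environment (assuming the package's lightrag-server entry point).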

View File: lightrag/api/lightrag_server.py

@@ -8,11 +8,12 @@ from fastapi import (
)
from fastapi.responses import FileResponse
import asyncio
import threading
import os
from fastapi.staticfiles import StaticFiles
import json
import logging
from typing import Dict
import logging.config
import uvicorn
from fastapi.staticfiles import StaticFiles
from pathlib import Path
import configparser
from ascii_colors import ASCIIColors
@@ -49,18 +50,6 @@ except Exception as e:
config = configparser.ConfigParser()
config.read("config.ini")
# Global progress tracker
scan_progress: Dict = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
# Lock for thread-safe operations
progress_lock = threading.Lock()
class AccessLogFilter(logging.Filter):
    def __init__(self):
@@ -95,7 +84,6 @@ class AccessLogFilter(logging.Filter):
def create_app(args):
    # Initialize verbose debug setting
    from lightrag.utils import set_verbose_debug
@@ -155,25 +143,12 @@ def create_app(args):
    # Auto scan documents if enabled
    if args.auto_scan_at_startup:
        # Start scanning in background
        with progress_lock:
            if not scan_progress["is_scanning"]:
                scan_progress["is_scanning"] = True
                scan_progress["indexed_count"] = 0
                scan_progress["progress"] = 0
                # Create background task
                task = asyncio.create_task(
                    run_scanning_process(rag, doc_manager)
                )
                app.state.background_tasks.add(task)
                task.add_done_callback(app.state.background_tasks.discard)
                ASCIIColors.info(
                    f"Started background scanning of documents from {args.input_dir}"
                )
            else:
                ASCIIColors.info(
                    "Skip document scanning (another scanning is active)"
                )
        # Create background task
        task = asyncio.create_task(
            run_scanning_process(rag, doc_manager)
        )
        app.state.background_tasks.add(task)
        task.add_done_callback(app.state.background_tasks.discard)

    ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
@@ -429,48 +404,67 @@ def create_app(args):
    return app
def get_application():
    """Factory function for creating the FastAPI application"""
    from .utils_api import initialize_manager

    initialize_manager()

    # Get args from environment variable
    args_json = os.environ.get('LIGHTRAG_ARGS')
    if not args_json:
        args = parse_args()  # Fallback to parsing args if env var not set
    else:
        import types

        args = types.SimpleNamespace(**json.loads(args_json))
    return create_app(args)
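Note: each uvicorn worker imports and builds the app in a fresh process, so the parent serializes the parsed CLI args into the LIGHTRAG_ARGS environment variable and every worker rebuilds them here. A minimal sketch of that round-trip (field values illustrative):

import json
import os
import types

# Parent: flatten the argparse.Namespace to JSON (str/int/bool values survive)
args = types.SimpleNamespace(host="0.0.0.0", port=9621, workers=2)
os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args))

# Worker: rebuild an attribute-style object from the inherited environment
restored = types.SimpleNamespace(**json.loads(os.environ["LIGHTRAG_ARGS"]))
assert restored.workers == 2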
def main():
    from multiprocessing import freeze_support

    freeze_support()

    args = parse_args()

    import uvicorn
    import logging.config

    # Save args to environment variable for child processes
    os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args))
    # Configure uvicorn logging
    logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {
                "format": "%(levelname)s: %(message)s",
            },
        },
        "handlers": {
            "default": {
                "formatter": "default",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            },
        },
        "loggers": {
            "uvicorn.access": {
                "handlers": ["default"],
                "level": "INFO",
                "propagate": False,
            },
        },
    })
    # Add filter to uvicorn access logger
    uvicorn_access_logger = logging.getLogger("uvicorn.access")
    uvicorn_access_logger.addFilter(AccessLogFilter())
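The body of AccessLogFilter is truncated in this diff; for reference, a filter of this shape is a standard logging.Filter whose filter() returns False for records to drop. A minimal sketch (the filtered paths are hypothetical, not the project's actual rules):

import logging

class AccessLogFilterSketch(logging.Filter):
    def __init__(self):
        super().__init__()
        # Hypothetical: suppress access-log lines for these paths
        self.filtered_paths = ["/health", "/docs", "/openapi.json"]

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        return not any(path in message for path in self.filtered_paths)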
    app = create_app(args)
    display_splash_screen(args)
    uvicorn_config = {
        "app": app,
        "app": "lightrag.api.lightrag_server:get_application",
        "factory": True,
        "host": args.host,
        "port": args.port,
        "workers": args.workers,
        "log_config": None,  # Disable default config
    }
    if args.ssl:

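The key change above: with workers > 1, uvicorn spawns child processes that must each import and build the app themselves, so the in-memory app object is replaced by an import string plus factory=True. A self-contained sketch of the pattern (module and route names are placeholders):

# myserver.py
from fastapi import FastAPI
import uvicorn

def get_application() -> FastAPI:
    # Called once per worker process; each worker gets its own app instance
    app = FastAPI()

    @app.get("/ping")
    async def ping():
        return {"pong": True}

    return app

if __name__ == "__main__":
    # An in-memory app object cannot be shipped to child processes;
    # an import string can, hence factory=True.
    uvicorn.run("myserver:get_application", factory=True, workers=2, port=8000)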
View File

@@ -12,29 +12,23 @@ import pipmaster as pm
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from ascii_colors import ASCIIColors
from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile
from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DocProcessingStatus, DocStatus
from ..utils_api import get_api_key_dependency
from ..utils_api import (
    get_api_key_dependency,
    scan_progress,
    update_scan_progress_if_not_scanning,
    update_scan_progress,
    reset_scan_progress,
)
router = APIRouter(prefix="/documents", tags=["documents"])
# Global progress tracker
scan_progress: Dict = {
    "is_scanning": False,
    "current_file": "",
    "indexed_count": 0,
    "total_files": 0,
    "progress": 0,
}
# Lock for thread-safe operations
progress_lock = asyncio.Lock()
# Temporary file prefix
temp_prefix = "__tmp__"
@@ -167,13 +161,6 @@ class DocumentManager:
            new_files.append(file_path)
        return new_files
    # def scan_directory(self) -> List[Path]:
    #     new_files = []
    #     for ext in self.supported_extensions:
    #         for file_path in self.input_dir.rglob(f"*{ext}"):
    #             new_files.append(file_path)
    #     return new_files
    def mark_as_indexed(self, file_path: Path):
        self.indexed_files.add(file_path)
@@ -390,24 +377,24 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path:
async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    """Background task to scan and index documents"""
    if not update_scan_progress_if_not_scanning():
        ASCIIColors.info(
            "Skip document scanning (another scanning is active)"
        )
        return

    try:
        new_files = doc_manager.scan_directory_for_new_files()
        scan_progress["total_files"] = len(new_files)
        total_files = len(new_files)
        update_scan_progress("", total_files, 0)  # Initialize progress

        logging.info(f"Found {len(new_files)} new files to index.")
        for file_path in new_files:
        logging.info(f"Found {total_files} new files to index.")
        for idx, file_path in enumerate(new_files):
            try:
                async with progress_lock:
                    scan_progress["current_file"] = os.path.basename(file_path)
                update_scan_progress(os.path.basename(file_path), total_files, idx)

                await pipeline_index_file(rag, file_path)

                async with progress_lock:
                    scan_progress["indexed_count"] += 1
                    scan_progress["progress"] = (
                        scan_progress["indexed_count"] / scan_progress["total_files"]
                    ) * 100
                update_scan_progress(os.path.basename(file_path), total_files, idx + 1)

            except Exception as e:
                logging.error(f"Error indexing file {file_path}: {str(e)}")
@@ -415,8 +402,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager):
    except Exception as e:
        logging.error(f"Error during scanning process: {str(e)}")
    finally:
        async with progress_lock:
            scan_progress["is_scanning"] = False
        reset_scan_progress()
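After the refactor every mutation of the shared progress state goes through the three helpers, giving the task a guard/update/reset shape. Condensed (file names and counts illustrative):

if update_scan_progress_if_not_scanning():          # atomic test-and-set guard
    try:
        files = ["a.txt", "b.txt"]
        update_scan_progress("", len(files), 0)     # initialize: 0% of 2 files
        for idx, name in enumerate(files):
            update_scan_progress(name, len(files), idx)      # entering file idx
            # ... index the file ...
            update_scan_progress(name, len(files), idx + 1)  # done: (idx + 1) / total * 100 %
    finally:
        reset_scan_progress()                       # always clear is_scanning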
def create_document_routes(
@@ -436,14 +422,6 @@ def create_document_routes(
        Returns:
            dict: A dictionary containing the scanning status
        """
        async with progress_lock:
            if scan_progress["is_scanning"]:
                return {"status": "already_scanning"}
            scan_progress["is_scanning"] = True
            scan_progress["indexed_count"] = 0
            scan_progress["progress"] = 0

        # Start the scanning process in the background
        background_tasks.add_task(run_scanning_process, rag, doc_manager)
        return {"status": "scanning_started"}
@@ -461,8 +439,7 @@ def create_document_routes(
            - total_files: Total number of files to process
            - progress: Percentage of completion
        """
        async with progress_lock:
            return scan_progress
        return dict(scan_progress)
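The dict(...) call matters because scan_progress is now a Manager proxy rather than a plain dict; copying it yields an ordinary, JSON-serializable snapshot. For example:

from multiprocessing import Manager

scan_progress = Manager().dict({"is_scanning": False, "progress": 0})
snapshot = dict(scan_progress)  # plain dict: point-in-time copy, safe to return from the endpoint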
    @router.post("/upload", dependencies=[Depends(optional_api_key)])
    async def upload_to_input_dir(

View File: lightrag/api/utils_api.py

@@ -6,6 +6,7 @@ import os
import argparse
from typing import Optional
import sys
from multiprocessing import Manager
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__
from fastapi import HTTPException, Security
@@ -16,6 +17,66 @@ from starlette.status import HTTP_403_FORBIDDEN
# Load environment variables
load_dotenv(override=True)
# Global variables for manager and shared state
manager = None
scan_progress = None
scan_lock = None
def initialize_manager():
    """Initialize manager and shared state for cross-process communication"""
    global manager, scan_progress, scan_lock
    if manager is None:
        manager = Manager()
        scan_progress = manager.dict({
            "is_scanning": False,
            "current_file": "",
            "indexed_count": 0,
            "total_files": 0,
            "progress": 0,
        })
        scan_lock = manager.Lock()
def update_scan_progress_if_not_scanning():
    """
    Atomically check whether a scan is already in progress and, if not, mark scanning as started.
    Returns True if the update was made, False if scanning was already in progress.
    """
    with scan_lock:
        if not scan_progress["is_scanning"]:
            scan_progress.update({
                "is_scanning": True,
                "current_file": "",
                "indexed_count": 0,
                "total_files": 0,
                "progress": 0,
            })
            return True
        return False
def update_scan_progress(current_file: str, total_files: int, indexed_count: int):
    """
    Atomically update scan progress information.
    """
    progress = (indexed_count / total_files * 100) if total_files > 0 else 0
    scan_progress.update({
        "current_file": current_file,
        "indexed_count": indexed_count,
        "total_files": total_files,
        "progress": progress,
    })
def reset_scan_progress():
    """
    Atomically reset scan progress to its initial state.
    """
    scan_progress.update({
        "is_scanning": False,
        "current_file": "",
        "indexed_count": 0,
        "total_files": 0,
        "progress": 0,
    })
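manager.dict() and manager.Lock() return proxies backed by a separate manager process, so every uvicorn worker that inherits them reads and writes the same state. A minimal self-contained demonstration (not project code):

from multiprocessing import Manager, Process

def worker(progress, lock, n):
    with lock:  # proxies are picklable, so they cross process boundaries
        progress["indexed_count"] += n

if __name__ == "__main__":
    manager = Manager()
    progress = manager.dict({"indexed_count": 0})
    lock = manager.Lock()
    procs = [Process(target=worker, args=(progress, lock, 1)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(progress["indexed_count"])  # 4 -- all workers mutated the same shared dict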
class OllamaServerInfos:
    # Constants for emulated Ollama model information
@@ -260,6 +321,14 @@ def parse_args() -> argparse.Namespace:
        help="Enable automatic scanning when the program starts",
    )

    # Server workers configuration
    parser.add_argument(
        "--workers",
        type=int,
        default=get_env_value("WORKERS", 2, int),
        help="Number of worker processes (default: from env or 2)",
    )

    # LLM and embedding bindings
    parser.add_argument(
        "--llm-binding",