From 2e13def95c93a4fb95aeed38c092ac9d43ef5d34 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 24 Feb 2025 18:20:39 +0800 Subject: [PATCH 01/77] Remove unused global_top_k variable and related configurations. --- lightrag/api/lightrag_server.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 9b2a1c76..0da555ed 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -49,9 +49,6 @@ except Exception as e: config = configparser.ConfigParser() config.read("config.ini") -# Global configuration -global_top_k = 60 # default value - # Global progress tracker scan_progress: Dict = { "is_scanning": False, @@ -98,9 +95,6 @@ class AccessLogFilter(logging.Filter): def create_app(args): - # Set global top_k - global global_top_k - global_top_k = args.top_k # save top_k from args # Initialize verbose debug setting from lightrag.utils import set_verbose_debug From d74a23d2cce86a2d697852b0849898ee00373ed3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 09:37:00 +0800 Subject: [PATCH 02/77] Add multiple workers support for API Server --- lightrag/api/lightrag_server.py | 116 +++++++++++------------- lightrag/api/routers/document_routes.py | 69 +++++--------- lightrag/api/utils_api.py | 69 ++++++++++++++ 3 files changed, 147 insertions(+), 107 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0da555ed..62cb24db 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -8,11 +8,12 @@ from fastapi import ( ) from fastapi.responses import FileResponse import asyncio -import threading import os -from fastapi.staticfiles import StaticFiles +import json import logging -from typing import Dict +import logging.config +import uvicorn +from fastapi.staticfiles import StaticFiles from pathlib import Path import configparser from ascii_colors import ASCIIColors @@ -49,18 +50,6 @@ except Exception as e: config = configparser.ConfigParser() config.read("config.ini") -# Global progress tracker -scan_progress: Dict = { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, -} - -# Lock for thread-safe operations -progress_lock = threading.Lock() - class AccessLogFilter(logging.Filter): def __init__(self): @@ -95,7 +84,6 @@ class AccessLogFilter(logging.Filter): def create_app(args): - # Initialize verbose debug setting from lightrag.utils import set_verbose_debug @@ -155,25 +143,12 @@ def create_app(args): # Auto scan documents if enabled if args.auto_scan_at_startup: - # Start scanning in background - with progress_lock: - if not scan_progress["is_scanning"]: - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - # Create background task - task = asyncio.create_task( - run_scanning_process(rag, doc_manager) - ) - app.state.background_tasks.add(task) - task.add_done_callback(app.state.background_tasks.discard) - ASCIIColors.info( - f"Started background scanning of documents from {args.input_dir}" - ) - else: - ASCIIColors.info( - "Skip document scanning(another scanning is active)" - ) + # Create background task + task = asyncio.create_task( + run_scanning_process(rag, doc_manager) + ) + app.state.background_tasks.add(task) + task.add_done_callback(app.state.background_tasks.discard) ASCIIColors.green("\nServer is ready to accept connections! 
๐Ÿš€\n") @@ -429,48 +404,67 @@ def create_app(args): return app +def get_application(): + """Factory function for creating the FastAPI application""" + from .utils_api import initialize_manager + initialize_manager() + + # Get args from environment variable + args_json = os.environ.get('LIGHTRAG_ARGS') + if not args_json: + args = parse_args() # Fallback to parsing args if env var not set + else: + import types + args = types.SimpleNamespace(**json.loads(args_json)) + + return create_app(args) + + def main(): + from multiprocessing import freeze_support + freeze_support() + args = parse_args() - import uvicorn - import logging.config + # Save args to environment variable for child processes + os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) # Configure uvicorn logging - logging.config.dictConfig( - { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(levelname)s: %(message)s", - }, + logging.config.dictConfig({ + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", }, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.StreamHandler", - "stream": "ext://sys.stderr", - }, + }, + "handlers": { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", }, - "loggers": { - "uvicorn.access": { - "handlers": ["default"], - "level": "INFO", - "propagate": False, - }, + }, + "loggers": { + "uvicorn.access": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, }, - } - ) + }, + }) # Add filter to uvicorn access logger uvicorn_access_logger = logging.getLogger("uvicorn.access") uvicorn_access_logger.addFilter(AccessLogFilter()) - app = create_app(args) display_splash_screen(args) + uvicorn_config = { - "app": app, + "app": "lightrag.api.lightrag_server:get_application", + "factory": True, "host": args.host, "port": args.port, + "workers": args.workers, "log_config": None, # Disable default config } if args.ssl: diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 5c742f39..ea6bf29d 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -12,29 +12,23 @@ import pipmaster as pm from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Any - +from ascii_colors import ASCIIColors from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from ..utils_api import get_api_key_dependency +from ..utils_api import ( + get_api_key_dependency, + scan_progress, + update_scan_progress_if_not_scanning, + update_scan_progress, + reset_scan_progress, +) router = APIRouter(prefix="/documents", tags=["documents"]) -# Global progress tracker -scan_progress: Dict = { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, -} - -# Lock for thread-safe operations -progress_lock = asyncio.Lock() - # Temporary file prefix temp_prefix = "__tmp__" @@ -167,13 +161,6 @@ class DocumentManager: new_files.append(file_path) return new_files - # def scan_directory(self) -> List[Path]: - # new_files = [] - # for ext in self.supported_extensions: - # for file_path in self.input_dir.rglob(f"*{ext}"): - # new_files.append(file_path) - # return new_files - def 
mark_as_indexed(self, file_path: Path): self.indexed_files.add(file_path) @@ -390,24 +377,24 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): - """Background task to scan and index documents""" + """Background task to scan and index documents""" + if not update_scan_progress_if_not_scanning(): + ASCIIColors.info( + "Skip document scanning(another scanning is active)" + ) + return + try: new_files = doc_manager.scan_directory_for_new_files() - scan_progress["total_files"] = len(new_files) + total_files = len(new_files) + update_scan_progress("", total_files, 0) # Initialize progress - logging.info(f"Found {len(new_files)} new files to index.") - for file_path in new_files: + logging.info(f"Found {total_files} new files to index.") + for idx, file_path in enumerate(new_files): try: - async with progress_lock: - scan_progress["current_file"] = os.path.basename(file_path) - + update_scan_progress(os.path.basename(file_path), total_files, idx) await pipeline_index_file(rag, file_path) - - async with progress_lock: - scan_progress["indexed_count"] += 1 - scan_progress["progress"] = ( - scan_progress["indexed_count"] / scan_progress["total_files"] - ) * 100 + update_scan_progress(os.path.basename(file_path), total_files, idx + 1) except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -415,8 +402,7 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): except Exception as e: logging.error(f"Error during scanning process: {str(e)}") finally: - async with progress_lock: - scan_progress["is_scanning"] = False + reset_scan_progress() def create_document_routes( @@ -436,14 +422,6 @@ def create_document_routes( Returns: dict: A dictionary containing the scanning status """ - async with progress_lock: - if scan_progress["is_scanning"]: - return {"status": "already_scanning"} - - scan_progress["is_scanning"] = True - scan_progress["indexed_count"] = 0 - scan_progress["progress"] = 0 - # Start the scanning process in the background background_tasks.add_task(run_scanning_process, rag, doc_manager) return {"status": "scanning_started"} @@ -461,8 +439,7 @@ def create_document_routes( - total_files: Total number of files to process - progress: Percentage of completion """ - async with progress_lock: - return scan_progress + return dict(scan_progress) @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 17f19627..da8d84fa 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -6,6 +6,7 @@ import os import argparse from typing import Optional import sys +from multiprocessing import Manager from ascii_colors import ASCIIColors from lightrag.api import __api_version__ from fastapi import HTTPException, Security @@ -16,6 +17,66 @@ from starlette.status import HTTP_403_FORBIDDEN # Load environment variables load_dotenv(override=True) +# Global variables for manager and shared state +manager = None +scan_progress = None +scan_lock = None + +def initialize_manager(): + """Initialize manager and shared state for cross-process communication""" + global manager, scan_progress, scan_lock + if manager is None: + manager = Manager() + scan_progress = manager.dict({ + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) + scan_lock = manager.Lock() + +def 
update_scan_progress_if_not_scanning(): + """ + Atomically check if scanning is not in progress and update scan_progress if it's not. + Returns True if the update was successful, False if scanning was already in progress. + """ + with scan_lock: + if not scan_progress["is_scanning"]: + scan_progress.update({ + "is_scanning": True, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) + return True + return False + +def update_scan_progress(current_file: str, total_files: int, indexed_count: int): + """ + Atomically update scan progress information. + """ + progress = (indexed_count / total_files * 100) if total_files > 0 else 0 + scan_progress.update({ + "current_file": current_file, + "indexed_count": indexed_count, + "total_files": total_files, + "progress": progress, + }) + +def reset_scan_progress(): + """ + Atomically reset scan progress to initial state. + """ + scan_progress.update({ + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) + class OllamaServerInfos: # Constants for emulated Ollama model information @@ -260,6 +321,14 @@ def parse_args() -> argparse.Namespace: help="Enable automatic scanning when the program starts", ) + # Server workers configuration + parser.add_argument( + "--workers", + type=int, + default=get_env_value("WORKERS", 2, int), + help="Number of worker processes (default: from env or 2)", + ) + # LLM and embedding bindings parser.add_argument( "--llm-binding", From ddc366b672b06eb4b0127782a150bcbbe25bd002 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 09:44:17 +0800 Subject: [PATCH 03/77] Optimize display_splash_screen function - Merge System Configuration into Server Configuration section - Add Workers parameter display after Port parameter --- lightrag/api/utils_api.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index da8d84fa..e0a783a2 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -415,17 +415,27 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.host}") ASCIIColors.white(" โ”œโ”€ Port: ", end="") ASCIIColors.yellow(f"{args.port}") + ASCIIColors.white(" โ”œโ”€ Workers: ", end="") + ASCIIColors.yellow(f"{args.workers}") ASCIIColors.white(" โ”œโ”€ CORS Origins: ", end="") ASCIIColors.yellow(f"{os.getenv('CORS_ORIGINS', '*')}") ASCIIColors.white(" โ”œโ”€ SSL Enabled: ", end="") ASCIIColors.yellow(f"{args.ssl}") - ASCIIColors.white(" โ””โ”€ API Key: ", end="") - ASCIIColors.yellow("Set" if args.key else "Not Set") if args.ssl: ASCIIColors.white(" โ”œโ”€ SSL Cert: ", end="") ASCIIColors.yellow(f"{args.ssl_certfile}") - ASCIIColors.white(" โ””โ”€ SSL Key: ", end="") + ASCIIColors.white(" โ”œโ”€ SSL Key: ", end="") ASCIIColors.yellow(f"{args.ssl_keyfile}") + ASCIIColors.white(" โ”œโ”€ Ollama Emulating Model: ", end="") + ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") + ASCIIColors.white(" โ”œโ”€ Log Level: ", end="") + ASCIIColors.yellow(f"{args.log_level}") + ASCIIColors.white(" โ”œโ”€ Verbose Debug: ", end="") + ASCIIColors.yellow(f"{args.verbose}") + ASCIIColors.white(" โ”œโ”€ Timeout: ", end="") + ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") + ASCIIColors.white(" โ””โ”€ API Key: ", end="") + ASCIIColors.yellow("Set" if args.key else "Not Set") # Directory Configuration ASCIIColors.magenta("\n๐Ÿ“‚ Directory Configuration:") @@ -484,15 +494,6 
@@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.white(" โ””โ”€ Document Status Storage: ", end="") ASCIIColors.yellow(f"{args.doc_status_storage}") - ASCIIColors.magenta("\n๐Ÿ› ๏ธ System Configuration:") - ASCIIColors.white(" โ”œโ”€ Ollama Emulating Model: ", end="") - ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") - ASCIIColors.white(" โ”œโ”€ Log Level: ", end="") - ASCIIColors.yellow(f"{args.log_level}") - ASCIIColors.white(" โ”œโ”€ Verbose Debug: ", end="") - ASCIIColors.yellow(f"{args.verbose}") - ASCIIColors.white(" โ””โ”€ Timeout: ", end="") - ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}") # Server Status ASCIIColors.green("\nโœจ Server starting up...\n") From 04fc5ce6041dc4731ff6faf4c09e4ed1cc91a204 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 09:45:14 +0800 Subject: [PATCH 04/77] Remove unspported endpoint from splash mesages --- lightrag/api/utils_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index e0a783a2..0d8bebd0 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -548,7 +548,6 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.cyan(""" 3. Basic Operations: - POST /upload_document: Upload new documents to RAG - POST /query: Query your document collection - - GET /collections: List available collections 4. Monitor the server: - Check server logs for detailed operation information From 7262f61b0ee434243f29be71255c4de42bd79639 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 10:47:27 +0800 Subject: [PATCH 05/77] add redis configuration and update workers default value --- .env.example | 4 ++++ lightrag/api/utils_api.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index e4034def..0f8e6c31 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,7 @@ ### Server Configuration # HOST=0.0.0.0 # PORT=9621 +# WORKERS=1 # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances # CORS_ORIGINS=http://localhost:3000,http://localhost:8080 @@ -138,3 +139,6 @@ MONGODB_GRAPH=false # deprecated (keep for backward compatibility) ### Qdrant QDRANT_URL=http://localhost:16333 # QDRANT_API_KEY=your-api-key + +### Redis +REDIS_URI=redis://localhost:6379 \ No newline at end of file diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 0d8bebd0..2544276a 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -325,8 +325,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--workers", type=int, - default=get_env_value("WORKERS", 2, int), - help="Number of worker processes (default: from env or 2)", + default=get_env_value("WORKERS", 1, int), + help="Number of worker processes (default: from env or 1)", ) # LLM and embedding bindings From 087d5770b028da1eb844ddfbefaf9b90bd24410e Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 11:10:13 +0800 Subject: [PATCH 06/77] feat(storage): Add shared memory support for file-based storage implementations This commit adds multiprocessing shared memory support to file-based storage implementations: - JsonDocStatusStorage - JsonKVStorage - NanoVectorDBStorage - NetworkXStorage Each storage module now uses module-level global variables with multiprocessing.Manager() to ensure data consistency across multiple uvicorn workers. All processes will see updates immediately when data is modified through ainsert function. 
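
A minimal sketch of the pattern these modules adopt (condensed from the diffs below; the helper names here are illustrative, not the exact ones added by this patch): a lazily created module-level Manager owns a proxy dict, and every storage instance reads and writes through that proxy.

    # Sketch only: lazily create one Manager per module and hand out its
    # cross-process dict proxy for storage instances to share.
    import threading
    from multiprocessing import Manager

    _init_lock = threading.Lock()
    _manager = None
    _shared_data = None

    def get_shared_dict():
        """Create the Manager on first use and return its shared dict proxy."""
        global _manager, _shared_data
        with _init_lock:
            if _manager is None:
                _manager = Manager()
                _shared_data = _manager.dict()
        return _shared_data
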
--- lightrag/kg/json_doc_status_impl.py | 44 +++++++++++++++++++++- lightrag/kg/json_kv_impl.py | 44 +++++++++++++++++++++- lightrag/kg/nano_vector_db_impl.py | 47 +++++++++++++++++++++-- lightrag/kg/networkx_impl.py | 58 ++++++++++++++++++++++++----- 4 files changed, 176 insertions(+), 17 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 63a295cd..431e340c 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,6 +1,8 @@ from dataclasses import dataclass import os from typing import Any, Union, final +import threading +from multiprocessing import Manager from lightrag.base import ( DocProcessingStatus, @@ -13,6 +15,25 @@ from lightrag.utils import ( write_json, ) +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_doc_status_data = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_doc_status_data + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_doc_status_data = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -22,8 +43,27 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data: dict[str, Any] = load_json(self._file_name) or {} - logger.info(f"Loaded document status storage with {len(self._data)} records") + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace data + if self.namespace not in _shared_doc_status_data: + with _init_lock: + if self.namespace not in _shared_doc_status_data: + try: + initial_data = load_json(self._file_name) or {} + _shared_doc_status_data[self.namespace] = initial_data + except Exception as e: + logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") + raise RuntimeError(f"Shared data initialization failed: {e}") + + try: + self._data = _shared_doc_status_data[self.namespace] + logger.info(f"Loaded document status storage with {len(self._data)} records") + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index e1ea507a..f03fda63 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -2,6 +2,8 @@ import asyncio import os from dataclasses import dataclass from typing import Any, final +import threading +from multiprocessing import Manager from lightrag.base import ( BaseKVStorage, @@ -12,6 +14,25 @@ from lightrag.utils import ( write_json, ) +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_kv_data = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_kv_data + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_kv_data = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") 
+ return _manager + @final @dataclass @@ -19,9 +40,28 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data: dict[str, Any] = load_json(self._file_name) or {} self._lock = asyncio.Lock() - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace data + if self.namespace not in _shared_kv_data: + with _init_lock: + if self.namespace not in _shared_kv_data: + try: + initial_data = load_json(self._file_name) or {} + _shared_kv_data[self.namespace] = initial_data + except Exception as e: + logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") + raise RuntimeError(f"Shared data initialization failed: {e}") + + try: + self._data = _shared_kv_data[self.namespace] + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def index_done_callback(self) -> None: write_json(self._data, self._file_name) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index b0900095..d68b7f42 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -3,6 +3,8 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np +import threading +from multiprocessing import Manager import time @@ -20,6 +22,25 @@ if not pm.is_installed("nano-vectordb"): from nano_vectordb import NanoVectorDB +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_vector_clients = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_vector_clients + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_vector_clients = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -40,9 +61,29 @@ class NanoVectorDBStorage(BaseVectorStorage): self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name - ) + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace client + if self.namespace not in _shared_vector_clients: + with _init_lock: + if self.namespace not in _shared_vector_clients: + try: + client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name + ) + _shared_vector_clients[self.namespace] = client + except Exception as e: + logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}") + raise RuntimeError(f"Vector DB client initialization failed: {e}") + + try: + self._client = _shared_vector_clients[self.namespace] + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index b4321458..581a4187 
100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,10 +1,11 @@ import os from dataclasses import dataclass from typing import Any, final +import threading +from multiprocessing import Manager import numpy as np - from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import ( logger, @@ -24,6 +25,25 @@ if not pm.is_installed("graspologic"): import networkx as nx from graspologic import embed +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_graphs = None + + +def _get_manager(): + """Get or create the global manager instance""" + global _manager, _shared_graphs + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_graphs = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -78,15 +98,33 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - self._graph = preloaded_graph or nx.Graph() - self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, - } + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace graph + if self.namespace not in _shared_graphs: + with _init_lock: + if self.namespace not in _shared_graphs: + try: + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + if preloaded_graph is not None: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + _shared_graphs[self.namespace] = preloaded_graph or nx.Graph() + except Exception as e: + logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}") + raise RuntimeError(f"Graph initialization failed: {e}") + + try: + self._graph = _shared_graphs[self.namespace] + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def index_done_callback(self) -> None: NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) From e22e014f228dc032b78dae2a375bd89c6607117f Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 11:25:06 +0800 Subject: [PATCH 07/77] feat(storage): Add shared memory support for FAISS --- lightrag/kg/faiss_impl.py | 77 +++++++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 2ac0899e..4324e965 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -2,6 +2,8 @@ import os import time import asyncio from typing import Any, final +import threading +from multiprocessing import Manager import json import numpy as np @@ -22,6 +24,27 @@ if not pm.is_installed("faiss"): import faiss +# Global variables for shared memory management +_init_lock = threading.Lock() +_manager = None +_shared_indices = None +_shared_meta = None + + +def _get_manager(): + """Get or create the global 
manager instance""" + global _manager, _shared_indices, _shared_meta + with _init_lock: + if _manager is None: + try: + _manager = Manager() + _shared_indices = _manager.dict() + _shared_meta = _manager.dict() + except Exception as e: + logger.error(f"Failed to initialize shared memory manager: {e}") + raise RuntimeError(f"Shared memory initialization failed: {e}") + return _manager + @final @dataclass @@ -50,18 +73,48 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim - - # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). - # If you have a large number of vectors, you might want IVF or other indexes. - # For demonstration, we use a simple IndexFlatIP. - self._index = faiss.IndexFlatIP(self._dim) - - # Keep a local store for metadata, IDs, etc. - # Maps โ†’ metadata (including your original ID). - self._id_to_meta = {} - - # Attempt to load an existing index + metadata from disk - self._load_faiss_index() + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace index and metadata + if self.namespace not in _shared_indices: + with _init_lock: + if self.namespace not in _shared_indices: + try: + # Create an empty Faiss index for inner product + index = faiss.IndexFlatIP(self._dim) + meta = {} + + # Load existing index if available + if os.path.exists(self._faiss_index_file): + try: + index = faiss.read_index(self._faiss_index_file) + with open(self._meta_file, "r", encoding="utf-8") as f: + stored_dict = json.load(f) + # Convert string keys back to int + meta = {int(k): v for k, v in stored_dict.items()} + logger.info( + f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}" + ) + except Exception as e: + logger.error(f"Failed to load Faiss index or metadata: {e}") + logger.warning("Starting with an empty Faiss index.") + index = faiss.IndexFlatIP(self._dim) + meta = {} + + _shared_indices[self.namespace] = index + _shared_meta[self.namespace] = meta + except Exception as e: + logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}") + raise RuntimeError(f"Faiss index initialization failed: {e}") + + try: + self._index = _shared_indices[self.namespace] + self._id_to_meta = _shared_meta[self.namespace] + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ From a642bb31904624b4317c2a15a2ee953b33825d01 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 12:08:49 +0800 Subject: [PATCH 08/77] refactor: use shared manager from main process for storage implementations. 
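
In outline, the storage modules stop constructing their own Manager() and instead reuse the single instance owned by the API process. A condensed, illustrative sketch of the shape of this change (not the literal patch code):

    # Sketch only: a central module owns one Manager for the whole process;
    # storage modules import it instead of calling Manager() themselves.
    from multiprocessing import Manager

    manager = None

    def initialize_manager():
        """Create the process-wide Manager once, before any storage is built."""
        global manager
        if manager is None:
            manager = Manager()

    # A storage module then reuses it, e.g.:
    #   from lightrag.api.utils_api import manager as main_process_manager
    #   _shared_dicts = main_process_manager.dict()
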
--- lightrag/kg/faiss_impl.py | 7 +++---- lightrag/kg/json_doc_status_impl.py | 4 ++-- lightrag/kg/json_kv_impl.py | 4 ++-- lightrag/kg/nano_vector_db_impl.py | 5 ++--- lightrag/kg/networkx_impl.py | 5 ++--- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 4324e965..2e129472 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -3,13 +3,12 @@ import time import asyncio from typing import Any, final import threading -from multiprocessing import Manager - import json import numpy as np from dataclasses import dataclass import pipmaster as pm +from lightrag.api.utils_api import manager as main_process_manager from lightrag.utils import ( logger, @@ -22,7 +21,7 @@ from lightrag.base import ( if not pm.is_installed("faiss"): pm.install("faiss") -import faiss +import faiss # type: ignore # Global variables for shared memory management _init_lock = threading.Lock() @@ -37,7 +36,7 @@ def _get_manager(): with _init_lock: if _manager is None: try: - _manager = Manager() + _manager = main_process_manager _shared_indices = _manager.dict() _shared_meta = _manager.dict() except Exception as e: diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 431e340c..dd3a7b64 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -2,7 +2,6 @@ from dataclasses import dataclass import os from typing import Any, Union, final import threading -from multiprocessing import Manager from lightrag.base import ( DocProcessingStatus, @@ -14,6 +13,7 @@ from lightrag.utils import ( logger, write_json, ) +from lightrag.api.utils_api import manager as main_process_manager # Global variables for shared memory management _init_lock = threading.Lock() @@ -27,7 +27,7 @@ def _get_manager(): with _init_lock: if _manager is None: try: - _manager = Manager() + _manager = main_process_manager _shared_doc_status_data = _manager.dict() except Exception as e: logger.error(f"Failed to initialize shared memory manager: {e}") diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index f03fda63..f5a8b488 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -3,7 +3,6 @@ import os from dataclasses import dataclass from typing import Any, final import threading -from multiprocessing import Manager from lightrag.base import ( BaseKVStorage, @@ -13,6 +12,7 @@ from lightrag.utils import ( logger, write_json, ) +from lightrag.api.utils_api import manager as main_process_manager # Global variables for shared memory management _init_lock = threading.Lock() @@ -26,7 +26,7 @@ def _get_manager(): with _init_lock: if _manager is None: try: - _manager = Manager() + _manager = main_process_manager _shared_kv_data = _manager.dict() except Exception as e: logger.error(f"Failed to initialize shared memory manager: {e}") diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index d68b7f42..7c15142e 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -4,14 +4,13 @@ from typing import Any, final from dataclasses import dataclass import numpy as np import threading -from multiprocessing import Manager - import time from lightrag.utils import ( logger, compute_mdhash_id, ) +from lightrag.api.utils_api import manager as main_process_manager import pipmaster as pm from lightrag.base import ( BaseVectorStorage, @@ -34,7 +33,7 @@ def _get_manager(): with _init_lock: if _manager is None: 
try: - _manager = Manager() + _manager = main_process_manager _shared_vector_clients = _manager.dict() except Exception as e: logger.error(f"Failed to initialize shared memory manager: {e}") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index c88d1c59..f3dd92dc 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -2,14 +2,13 @@ import os from dataclasses import dataclass from typing import Any, final import threading -from multiprocessing import Manager - import numpy as np from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import ( logger, ) +from lightrag.api.utils_api import manager as main_process_manager from lightrag.base import ( BaseGraphStorage, @@ -37,7 +36,7 @@ def _get_manager(): with _init_lock: if _manager is None: try: - _manager = Manager() + _manager = main_process_manager _shared_graphs = _manager.dict() except Exception as e: logger.error(f"Failed to initialize shared memory manager: {e}") From 8050b0f91b2fabf79d28131af5b016607a37c84a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 25 Feb 2025 12:09:30 +0800 Subject: [PATCH 09/77] feat: automatically initialize API manager in single process mode - Add manager init check in __post_init__ - Call initialize_manager if needed - Add info log message for init - Ensure API manager is ready for use --- lightrag/lightrag.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 46638243..c115b33a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -267,6 +267,12 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): + # Initialize manager if needed + from lightrag.api.utils_api import manager, initialize_manager + if manager is None: + initialize_manager() + logger.info("Initialized manager for single process mode") + os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") From 2752a764ae39acb824cf519caaeabae689729a6b Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 05:38:38 +0800 Subject: [PATCH 10/77] Refactor storage implementations to support both single and multi-process modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add shared storage management module โ€ข Support process/thread lock based on mode --- lightrag/api/lightrag_server.py | 9 +- lightrag/api/routers/document_routes.py | 66 +++-- lightrag/api/utils_api.py | 61 ----- lightrag/kg/faiss_impl.py | 309 +++++++++++------------ lightrag/kg/json_doc_status_impl.py | 108 +++----- lightrag/kg/json_kv_impl.py | 91 +++---- lightrag/kg/nano_vector_db_impl.py | 165 ++++++------ lightrag/kg/networkx_impl.py | 320 ++++++++++++------------ lightrag/kg/shared_storage.py | 94 +++++++ lightrag/lightrag.py | 8 +- 10 files changed, 608 insertions(+), 623 deletions(-) create mode 100644 lightrag/kg/shared_storage.py diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 62cb24db..65227e97 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -406,9 +406,6 @@ def create_app(args): def get_application(): """Factory function for creating the FastAPI application""" - from .utils_api import initialize_manager - initialize_manager() - # Get args from environment variable args_json = 
os.environ.get('LIGHTRAG_ARGS') if not args_json: @@ -428,6 +425,12 @@ def main(): # Save args to environment variable for child processes os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) + if args.workers > 1: + from lightrag.kg.shared_storage import initialize_manager + initialize_manager() + import lightrag.kg.shared_storage as shared_storage + shared_storage.is_multiprocess = True + # Configure uvicorn logging logging.config.dictConfig({ "version": 1, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ea6bf29d..c084023d 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -18,12 +18,10 @@ from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus -from ..utils_api import ( - get_api_key_dependency, - scan_progress, - update_scan_progress_if_not_scanning, - update_scan_progress, - reset_scan_progress, +from ..utils_api import get_api_key_dependency +from lightrag.kg.shared_storage import ( + get_scan_progress, + get_scan_lock, ) @@ -378,23 +376,51 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" - if not update_scan_progress_if_not_scanning(): - ASCIIColors.info( - "Skip document scanning(another scanning is active)" - ) - return + scan_progress = get_scan_progress() + scan_lock = get_scan_lock() + + with scan_lock: + if scan_progress["is_scanning"]: + ASCIIColors.info( + "Skip document scanning(another scanning is active)" + ) + return + scan_progress.update({ + "is_scanning": True, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - update_scan_progress("", total_files, 0) # Initialize progress + scan_progress.update({ + "current_file": "", + "total_files": total_files, + "indexed_count": 0, + "progress": 0, + }) logging.info(f"Found {total_files} new files to index.") for idx, file_path in enumerate(new_files): try: - update_scan_progress(os.path.basename(file_path), total_files, idx) + progress = (idx / total_files * 100) if total_files > 0 else 0 + scan_progress.update({ + "current_file": os.path.basename(file_path), + "indexed_count": idx, + "progress": progress, + }) + await pipeline_index_file(rag, file_path) - update_scan_progress(os.path.basename(file_path), total_files, idx + 1) + + progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0 + scan_progress.update({ + "current_file": os.path.basename(file_path), + "indexed_count": idx + 1, + "progress": progress, + }) except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -402,7 +428,13 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): except Exception as e: logging.error(f"Error during scanning process: {str(e)}") finally: - reset_scan_progress() + scan_progress.update({ + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) def create_document_routes( @@ -427,7 +459,7 @@ def create_document_routes( return {"status": "scanning_started"} @router.get("/scan-progress") - async def get_scan_progress(): + async def get_scanning_progress(): """ Get the current progress of the document scanning process. 
@@ -439,7 +471,7 @@ def create_document_routes( - total_files: Total number of files to process - progress: Percentage of completion """ - return dict(scan_progress) + return dict(get_scan_progress()) @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 2544276a..6b501e64 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -6,7 +6,6 @@ import os import argparse from typing import Optional import sys -from multiprocessing import Manager from ascii_colors import ASCIIColors from lightrag.api import __api_version__ from fastapi import HTTPException, Security @@ -17,66 +16,6 @@ from starlette.status import HTTP_403_FORBIDDEN # Load environment variables load_dotenv(override=True) -# Global variables for manager and shared state -manager = None -scan_progress = None -scan_lock = None - -def initialize_manager(): - """Initialize manager and shared state for cross-process communication""" - global manager, scan_progress, scan_lock - if manager is None: - manager = Manager() - scan_progress = manager.dict({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - scan_lock = manager.Lock() - -def update_scan_progress_if_not_scanning(): - """ - Atomically check if scanning is not in progress and update scan_progress if it's not. - Returns True if the update was successful, False if scanning was already in progress. - """ - with scan_lock: - if not scan_progress["is_scanning"]: - scan_progress.update({ - "is_scanning": True, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - return True - return False - -def update_scan_progress(current_file: str, total_files: int, indexed_count: int): - """ - Atomically update scan progress information. - """ - progress = (indexed_count / total_files * 100) if total_files > 0 else 0 - scan_progress.update({ - "current_file": current_file, - "indexed_count": indexed_count, - "total_files": total_files, - "progress": progress, - }) - -def reset_scan_progress(): - """ - Atomically reset scan progress to initial state. 
- """ - scan_progress.update({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - class OllamaServerInfos: # Constants for emulated Ollama model information diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 2e129472..8c9c52c4 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -2,48 +2,21 @@ import os import time import asyncio from typing import Any, final -import threading import json import numpy as np from dataclasses import dataclass import pipmaster as pm -from lightrag.api.utils_api import manager as main_process_manager -from lightrag.utils import ( - logger, - compute_mdhash_id, -) -from lightrag.base import ( - BaseVectorStorage, -) +from lightrag.utils import logger,compute_mdhash_id +from lightrag.base import BaseVectorStorage +from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess if not pm.is_installed("faiss"): pm.install("faiss") import faiss # type: ignore -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_indices = None -_shared_meta = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_indices, _shared_meta - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_indices = _manager.dict() - _shared_meta = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager - @final @dataclass @@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 
768) must match your embedding function self._dim = self.embedding_func.embedding_dim + self._storage_lock = get_storage_lock() - # Ensure manager is initialized - _get_manager() + self._index = get_namespace_object('faiss_indices') + self._id_to_meta = get_namespace_data('faiss_meta') - # Get or create namespace index and metadata - if self.namespace not in _shared_indices: - with _init_lock: - if self.namespace not in _shared_indices: - try: - # Create an empty Faiss index for inner product - index = faiss.IndexFlatIP(self._dim) - meta = {} - - # Load existing index if available - if os.path.exists(self._faiss_index_file): - try: - index = faiss.read_index(self._faiss_index_file) - with open(self._meta_file, "r", encoding="utf-8") as f: - stored_dict = json.load(f) - # Convert string keys back to int - meta = {int(k): v for k, v in stored_dict.items()} - logger.info( - f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}" - ) - except Exception as e: - logger.error(f"Failed to load Faiss index or metadata: {e}") - logger.warning("Starting with an empty Faiss index.") - index = faiss.IndexFlatIP(self._dim) - meta = {} - - _shared_indices[self.namespace] = index - _shared_meta[self.namespace] = meta - except Exception as e: - logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}") - raise RuntimeError(f"Faiss index initialization failed: {e}") - - try: - self._index = _shared_indices[self.namespace] - self._id_to_meta = _shared_meta[self.namespace] - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + with self._storage_lock: + if is_multiprocess: + if self._index.value is None: + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). + # If you have a large number of vectors, you might want IVF or other indexes. + # For demonstration, we use a simple IndexFlatIP. + self._index.value = faiss.IndexFlatIP(self._dim) + else: + if self._index is None: + self._index = faiss.IndexFlatIP(self._dim) + + # Keep a local store for metadata, IDs, etc. + # Maps โ†’ metadata (including your original ID). + self._id_to_meta.update({}) + + # Attempt to load an existing index + metadata from disk + self._load_faiss_index() + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ @@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage): # Normalize embeddings for cosine similarity (in-place) faiss.normalize_L2(embeddings) - # Upsert logic: - # 1. Identify which vectors to remove if they exist - # 2. Remove them - # 3. Add the new vectors - existing_ids_to_remove = [] - for meta, emb in zip(list_data, embeddings): - faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) - if faiss_internal_id is not None: - existing_ids_to_remove.append(faiss_internal_id) + with self._storage_lock: + # Upsert logic: + # 1. Identify which vectors to remove if they exist + # 2. Remove them + # 3. 
Add the new vectors + existing_ids_to_remove = [] + for meta, emb in zip(list_data, embeddings): + faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) + if faiss_internal_id is not None: + existing_ids_to_remove.append(faiss_internal_id) - if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + if existing_ids_to_remove: + self._remove_faiss_ids(existing_ids_to_remove) - # Step 2: Add new vectors - start_idx = self._index.ntotal - self._index.add(embeddings) + # Step 2: Add new vectors + start_idx = (self._index.value if is_multiprocess else self._index).ntotal + if is_multiprocess: + self._index.value.add(embeddings) + else: + self._index.add(embeddings) - # Step 3: Store metadata + vector for each new ID - for i, meta in enumerate(list_data): - fid = start_idx + i - # Store the raw vector so we can rebuild if something is removed - meta["__vector__"] = embeddings[i].tolist() - self._id_to_meta[fid] = meta + # Step 3: Store metadata + vector for each new ID + for i, meta in enumerate(list_data): + fid = start_idx + i + # Store the raw vector so we can rebuild if something is removed + meta["__vector__"] = embeddings[i].tolist() + self._id_to_meta.update({fid: meta}) - logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") - return [m["__id__"] for m in list_data] + logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") + return [m["__id__"] for m in list_data] async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ @@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - distances, indices = self._index.search(embedding, top_k) + with self._storage_lock: + distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k) - distances = distances[0] - indices = indices[0] + distances = distances[0] + indices = indices[0] - results = [] - for dist, idx in zip(distances, indices): - if idx == -1: - # Faiss returns -1 if no neighbor - continue + results = [] + for dist, idx in zip(distances, indices): + if idx == -1: + # Faiss returns -1 if no neighbor + continue - # Cosine similarity threshold - if dist < self.cosine_better_than_threshold: - continue + # Cosine similarity threshold + if dist < self.cosine_better_than_threshold: + continue - meta = self._id_to_meta.get(idx, {}) - results.append( - { - **meta, - "id": meta.get("__id__"), - "distance": float(dist), - "created_at": meta.get("__created_at__"), - } - ) + meta = self._id_to_meta.get(idx, {}) + results.append( + { + **meta, + "id": meta.get("__id__"), + "distance": float(dist), + "created_at": meta.get("__created_at__"), + } + ) - return results + return results @property def client_storage(self): # Return whatever structure LightRAG might need for debugging - return {"data": list(self._id_to_meta.values())} + with self._storage_lock: + return {"data": list(self._id_to_meta.values())} async def delete(self, ids: list[str]): """ Delete vectors for the provided custom IDs. 
""" logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") - to_remove = [] - for cid in ids: - fid = self._find_faiss_id_by_custom_id(cid) - if fid is not None: - to_remove.append(fid) + with self._storage_lock: + to_remove = [] + for cid in ids: + fid = self._find_faiss_id_by_custom_id(cid) + if fid is not None: + to_remove.append(fid) - if to_remove: - self._remove_faiss_ids(to_remove) - logger.info( - f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" - ) + if to_remove: + self._remove_faiss_ids(to_remove) + logger.debug( + f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" + ) async def delete_entity(self, entity_name: str) -> None: entity_id = compute_mdhash_id(entity_name, prefix="ent-") @@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage): Delete relations for a given entity by scanning metadata. """ logger.debug(f"Searching relations for entity {entity_name}") - relations = [] - for fid, meta in self._id_to_meta.items(): - if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: - relations.append(fid) + with self._storage_lock: + relations = [] + for fid, meta in self._id_to_meta.items(): + if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: + relations.append(fid) - logger.debug(f"Found {len(relations)} relations for {entity_name}") - if relations: - self._remove_faiss_ids(relations) - logger.debug(f"Deleted {len(relations)} relations for {entity_name}") + logger.debug(f"Found {len(relations)} relations for {entity_name}") + if relations: + self._remove_faiss_ids(relations) + logger.debug(f"Deleted {len(relations)} relations for {entity_name}") async def index_done_callback(self) -> None: - self._save_faiss_index() + with self._storage_lock: + self._save_faiss_index() # -------------------------------------------------------------------------------- # Internal helper methods @@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage): """ Return the Faiss internal ID for a given custom ID, or None if not found. """ - for fid, meta in self._id_to_meta.items(): - if meta.get("__id__") == custom_id: - return fid - return None + with self._storage_lock: + for fid, meta in self._id_to_meta.items(): + if meta.get("__id__") == custom_id: + return fid + return None def _remove_faiss_ids(self, fid_list): """ @@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage): Because IndexFlatIP doesn't support 'removals', we rebuild the index excluding those vectors. 
""" - keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] + with self._storage_lock: + keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] - # Rebuild the index - vectors_to_keep = [] - new_id_to_meta = {} - for new_fid, old_fid in enumerate(keep_fids): - vec_meta = self._id_to_meta[old_fid] - vectors_to_keep.append(vec_meta["__vector__"]) # stored as list - new_id_to_meta[new_fid] = vec_meta + # Rebuild the index + vectors_to_keep = [] + new_id_to_meta = {} + for new_fid, old_fid in enumerate(keep_fids): + vec_meta = self._id_to_meta[old_fid] + vectors_to_keep.append(vec_meta["__vector__"]) # stored as list + new_id_to_meta[new_fid] = vec_meta - # Re-init index - self._index = faiss.IndexFlatIP(self._dim) - if vectors_to_keep: - arr = np.array(vectors_to_keep, dtype=np.float32) - self._index.add(arr) + # Re-init index + new_index = faiss.IndexFlatIP(self._dim) + if vectors_to_keep: + arr = np.array(vectors_to_keep, dtype=np.float32) + new_index.add(arr) + if is_multiprocess: + self._index.value = new_index + else: + self._index = new_index - self._id_to_meta = new_id_to_meta + self._id_to_meta.update(new_id_to_meta) def _save_faiss_index(self): """ Save the current Faiss index + metadata to disk so it can persist across runs. """ - faiss.write_index(self._index, self._faiss_index_file) + with self._storage_lock: + faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file) - # Save metadata dict to JSON. Convert all keys to strings for JSON storage. - # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } - # We'll keep the int -> dict, but JSON requires string keys. - serializable_dict = {} - for fid, meta in self._id_to_meta.items(): - serializable_dict[str(fid)] = meta + # Save metadata dict to JSON. Convert all keys to strings for JSON storage. + # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } + # We'll keep the int -> dict, but JSON requires string keys. 
+ serializable_dict = {} + for fid, meta in self._id_to_meta.items(): + serializable_dict[str(fid)] = meta - with open(self._meta_file, "w", encoding="utf-8") as f: - json.dump(serializable_dict, f) + with open(self._meta_file, "w", encoding="utf-8") as f: + json.dump(serializable_dict, f) def _load_faiss_index(self): """ @@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage): try: # Load the Faiss index - self._index = faiss.read_index(self._faiss_index_file) + loaded_index = faiss.read_index(self._faiss_index_file) + if is_multiprocess: + self._index.value = loaded_index + else: + self._index = loaded_index + # Load metadata with open(self._meta_file, "r", encoding="utf-8") as f: stored_dict = json.load(f) # Convert string keys back to int - self._id_to_meta = {} + self._id_to_meta.update({}) for fid_str, meta in stored_dict.items(): fid = int(fid_str) self._id_to_meta[fid] = meta logger.info( - f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}" + f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}" ) except Exception as e: logger.error(f"Failed to load Faiss index or metadata: {e}") logger.warning("Starting with an empty Faiss index.") - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} + new_index = faiss.IndexFlatIP(self._dim) + if is_multiprocess: + self._index.value = new_index + else: + self._index = new_index + self._id_to_meta.update({}) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index dd3a7b64..50451f95 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import os from typing import Any, Union, final -import threading from lightrag.base import ( DocProcessingStatus, @@ -13,26 +12,7 @@ from lightrag.utils import ( logger, write_json, ) -from lightrag.api.utils_api import manager as main_process_manager - -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_doc_status_data = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_doc_status_data - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_doc_status_data = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager +from .shared_storage import get_namespace_data, get_storage_lock @final @@ -43,45 +23,32 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - - # Ensure manager is initialized - _get_manager() - - # Get or create namespace data - if self.namespace not in _shared_doc_status_data: - with _init_lock: - if self.namespace not in _shared_doc_status_data: - try: - initial_data = load_json(self._file_name) or {} - _shared_doc_status_data[self.namespace] = initial_data - except Exception as e: - logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") - raise RuntimeError(f"Shared data initialization failed: {e}") - - try: - self._data = _shared_doc_status_data[self.namespace] - logger.info(f"Loaded document status storage with {len(self._data)} records") - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise 
RuntimeError(f"Cannot access shared memory: {e}") + self._storage_lock = get_storage_lock() + self._data = get_namespace_data(self.namespace) + with self._storage_lock: + self._data.update(load_json(self._file_name) or {}) + logger.info(f"Loaded document status storage with {len(self._data)} records") async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - return set(keys) - set(self._data.keys()) + with self._storage_lock: + return set(keys) - set(self._data.keys()) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] - for id in ids: - data = self._data.get(id, None) - if data: - result.append(data) + with self._storage_lock: + for id in ids: + data = self._data.get(id, None) + if data: + result.append(data) return result async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" counts = {status.value: 0 for status in DocStatus} - for doc in self._data.values(): - counts[doc["status"]] += 1 + with self._storage_lock: + for doc in self._data.values(): + counts[doc["status"]] += 1 return counts async def get_docs_by_status( @@ -89,39 +56,46 @@ class JsonDocStatusStorage(DocStatusStorage): ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" result = {} - for k, v in self._data.items(): - if v["status"] == status.value: - try: - # Make a copy of the data to avoid modifying the original - data = v.copy() - # If content is missing, use content_summary as content - if "content" not in data and "content_summary" in data: - data["content"] = data["content_summary"] - result[k] = DocProcessingStatus(**data) - except KeyError as e: - logger.error(f"Missing required field for document {k}: {e}") - continue + with self._storage_lock: + for k, v in self._data.items(): + if v["status"] == status.value: + try: + # Make a copy of the data to avoid modifying the original + data = v.copy() + # If content is missing, use content_summary as content + if "content" not in data and "content_summary" in data: + data["content"] = data["content_summary"] + result[k] = DocProcessingStatus(**data) + except KeyError as e: + logger.error(f"Missing required field for document {k}: {e}") + continue return result async def index_done_callback(self) -> None: - write_json(self._data, self._file_name) + # ๆ–‡ไปถๅ†™ๅ…ฅ้œ€่ฆๅŠ ้”๏ผŒ้˜ฒๆญขๅคšไธช่ฟ›็จ‹ๅŒๆ—ถๅ†™ๅ…ฅๅฏผ่‡ดๆ–‡ไปถๆŸๅ + with self._storage_lock: + write_json(self._data, self._file_name) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - self._data.update(data) + with self._storage_lock: + self._data.update(data) await self.index_done_callback() async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - return self._data.get(id) + with self._storage_lock: + return self._data.get(id) async def delete(self, doc_ids: list[str]): - for doc_id in doc_ids: - self._data.pop(doc_id, None) + with self._storage_lock: + for doc_id in doc_ids: + self._data.pop(doc_id, None) await self.index_done_callback() async def drop(self) -> None: """Drop the storage""" - self._data.clear() + with self._storage_lock: + self._data.clear() diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index f5a8b488..a53ac8f0 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,8 +1,6 @@ -import asyncio import os from dataclasses import 
dataclass from typing import Any, final -import threading from lightrag.base import ( BaseKVStorage, @@ -12,26 +10,7 @@ from lightrag.utils import ( logger, write_json, ) -from lightrag.api.utils_api import manager as main_process_manager - -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_kv_data = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_kv_data - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_kv_data = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager +from .shared_storage import get_namespace_data, get_storage_lock @final @@ -39,57 +18,49 @@ def _get_manager(): class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._lock = asyncio.Lock() - - # Ensure manager is initialized - _get_manager() - - # Get or create namespace data - if self.namespace not in _shared_kv_data: - with _init_lock: - if self.namespace not in _shared_kv_data: - try: - initial_data = load_json(self._file_name) or {} - _shared_kv_data[self.namespace] = initial_data - except Exception as e: - logger.error(f"Failed to initialize shared data for namespace {self.namespace}: {e}") - raise RuntimeError(f"Shared data initialization failed: {e}") - - try: - self._data = _shared_kv_data[self.namespace] - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + self._storage_lock = get_storage_lock() + self._data = get_namespace_data(self.namespace) + with self._storage_lock: + if not self._data: + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data: dict[str, Any] = load_json(self._file_name) or {} + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + async def index_done_callback(self) -> None: - write_json(self._data, self._file_name) + # ๆ–‡ไปถๅ†™ๅ…ฅ้œ€่ฆๅŠ ้”๏ผŒ้˜ฒๆญขๅคšไธช่ฟ›็จ‹ๅŒๆ—ถๅ†™ๅ…ฅๅฏผ่‡ดๆ–‡ไปถๆŸๅ + with self._storage_lock: + write_json(self._data, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: - return self._data.get(id) + with self._storage_lock: + return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - return [ - ( - {k: v for k, v in self._data[id].items()} - if self._data.get(id, None) - else None - ) - for id in ids - ] + with self._storage_lock: + return [ + ( + {k: v for k, v in self._data[id].items()} + if self._data.get(id, None) + else None + ) + for id in ids + ] async def filter_keys(self, keys: set[str]) -> set[str]: - return set(keys) - set(self._data.keys()) + with self._storage_lock: + return set(keys) - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) + with self._storage_lock: + left_data = {k: v for k, v in data.items() if k not in self._data} + self._data.update(left_data) async def delete(self, ids: list[str]) -> None: - for doc_id in ids: - 
self._data.pop(doc_id, None) + with self._storage_lock: + for doc_id in ids: + self._data.pop(doc_id, None) await self.index_done_callback() diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 7c15142e..07f8d367 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -3,50 +3,29 @@ import os from typing import Any, final from dataclasses import dataclass import numpy as np -import threading import time from lightrag.utils import ( logger, compute_mdhash_id, ) -from lightrag.api.utils_api import manager as main_process_manager import pipmaster as pm -from lightrag.base import ( - BaseVectorStorage, -) +from lightrag.base import BaseVectorStorage +from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_vector_clients = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_vector_clients - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_vector_clients = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager - @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Initialize lock only for file operations - self._save_lock = asyncio.Lock() + self._storage_lock = get_storage_lock() + # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -61,28 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self._max_batch_size = self.global_config["embedding_batch_num"] - # Ensure manager is initialized - _get_manager() + self._client = get_namespace_object(self.namespace) - # Get or create namespace client - if self.namespace not in _shared_vector_clients: - with _init_lock: - if self.namespace not in _shared_vector_clients: - try: - client = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name - ) - _shared_vector_clients[self.namespace] = client - except Exception as e: - logger.error(f"Failed to initialize vector DB client for namespace {self.namespace}: {e}") - raise RuntimeError(f"Vector DB client initialization failed: {e}") + with self._storage_lock: + if is_multiprocess: + if self._client.value is None: + self._client.value = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + else: + if self._client is None: + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) - try: - self._client = _shared_vector_clients[self.namespace] - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + logger.info(f"Initialized vector DB client for namespace {self.namespace}") + + def _get_client(self): + """Get the appropriate client instance based on multiprocess mode""" + if is_multiprocess: + return self._client.value + return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") @@ -104,6 +82,7 @@ class 
NanoVectorDBStorage(BaseVectorStorage): for i in range(0, len(contents), self._max_batch_size) ] + # Execute embedding outside of lock to avoid long lock times embedding_tasks = [self.embedding_func(batch) for batch in batches] embeddings_list = await asyncio.gather(*embedding_tasks) @@ -111,7 +90,9 @@ class NanoVectorDBStorage(BaseVectorStorage): if len(embeddings) == len(list_data): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - results = self._client.upsert(datas=list_data) + with self._storage_lock: + client = self._get_client() + results = client.upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. @@ -120,27 +101,32 @@ class NanoVectorDBStorage(BaseVectorStorage): ) async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + # Execute embedding outside of lock to avoid long lock times embedding = await self.embedding_func([query]) embedding = embedding[0] - results = self._client.query( - query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - { - **dp, - "id": dp["__id__"], - "distance": dp["__metrics__"], - "created_at": dp.get("__created_at__"), - } - for dp in results - ] + + with self._storage_lock: + client = self._get_client() + results = client.query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] return results @property def client_storage(self): - return getattr(self._client, "_NanoVectorDB__storage") + client = self._get_client() + return getattr(client, "_NanoVectorDB__storage") async def delete(self, ids: list[str]): """Delete vectors with specified IDs @@ -149,8 +135,10 @@ class NanoVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - self._client.delete(ids) - logger.info( + with self._storage_lock: + client = self._get_client() + client.delete(ids) + logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) except Exception as e: @@ -162,35 +150,42 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.debug( f"Attempting to delete entity {entity_name} with ID {entity_id}" ) - # Check if the entity exists - if self._client.get([entity_id]): - await self.delete([entity_id]) - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") + + with self._storage_lock: + client = self._get_client() + # Check if the entity exists + if client.get([entity_id]): + client.delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: try: - relations = [ - dp - for dp in self.client_storage["data"] - if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name - ] - logger.debug(f"Found {len(relations)} relations for entity {entity_name}") - ids_to_delete = [relation["__id__"] for relation in relations] + with self._storage_lock: + client = self._get_client() + storage = getattr(client, "_NanoVectorDB__storage") + relations = [ + dp + for dp in storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug(f"Found {len(relations)} relations for entity 
{entity_name}") + ids_to_delete = [relation["__id__"] for relation in relations] - if ids_to_delete: - await self.delete(ids_to_delete) - logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" - ) - else: - logger.debug(f"No relations found for entity {entity_name}") + if ids_to_delete: + client.delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self) -> None: - async with self._save_lock: - self._client.save() + with self._storage_lock: + client = self._get_client() + client.save() diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index f3dd92dc..df07499b 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,18 +1,13 @@ import os from dataclasses import dataclass from typing import Any, final -import threading import numpy as np from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from lightrag.utils import ( - logger, -) -from lightrag.api.utils_api import manager as main_process_manager +from lightrag.utils import logger +from lightrag.base import BaseGraphStorage +from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess -from lightrag.base import ( - BaseGraphStorage, -) import pipmaster as pm if not pm.is_installed("networkx"): @@ -24,25 +19,6 @@ if not pm.is_installed("graspologic"): import networkx as nx from graspologic import embed -# Global variables for shared memory management -_init_lock = threading.Lock() -_manager = None -_shared_graphs = None - - -def _get_manager(): - """Get or create the global manager instance""" - global _manager, _shared_graphs - with _init_lock: - if _manager is None: - try: - _manager = main_process_manager - _shared_graphs = _manager.dict() - except Exception as e: - logger.error(f"Failed to initialize shared memory manager: {e}") - raise RuntimeError(f"Shared memory initialization failed: {e}") - return _manager - @final @dataclass @@ -97,76 +73,98 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) - - # Ensure manager is initialized - _get_manager() - - # Get or create namespace graph - if self.namespace not in _shared_graphs: - with _init_lock: - if self.namespace not in _shared_graphs: - try: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - _shared_graphs[self.namespace] = preloaded_graph or nx.Graph() - except Exception as e: - logger.error(f"Failed to initialize graph for namespace {self.namespace}: {e}") - raise RuntimeError(f"Graph initialization failed: {e}") - - try: - self._graph = _shared_graphs[self.namespace] - self._node_embed_algorithms = { + self._storage_lock = get_storage_lock() + self._graph = get_namespace_object(self.namespace) + with self._storage_lock: + if is_multiprocess: + if self._graph.value is None: + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + self._graph.value = preloaded_graph or nx.Graph() + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, 
{preloaded_graph.number_of_edges()} edges" + ) + else: + if self._graph is None: + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + self._graph = preloaded_graph or nx.Graph() + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + + self._node_embed_algorithms = { "node2vec": self._node2vec_embed, - } - except Exception as e: - logger.error(f"Failed to access shared memory: {e}") - raise RuntimeError(f"Cannot access shared memory: {e}") + } + + def _get_graph(self): + """Get the appropriate graph instance based on multiprocess mode""" + if is_multiprocess: + return self._graph.value + return self._graph async def index_done_callback(self) -> None: - NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) + with self._storage_lock: + graph = self._get_graph() + NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: - return self._graph.has_node(node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - return self._graph.has_edge(source_node_id, target_node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: - return self._graph.nodes.get(node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: - return self._graph.degree(node_id) + with self._storage_lock: + graph = self._get_graph() + return graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._graph.degree(src_id) + self._graph.degree(tgt_id) + with self._storage_lock: + graph = self._get_graph() + return graph.degree(src_id) + graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - return self._graph.edges.get((source_node_id, target_node_id)) + with self._storage_lock: + graph = self._get_graph() + return graph.edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - if self._graph.has_node(source_node_id): - return list(self._graph.edges(source_node_id)) - return None + with self._storage_lock: + graph = self._get_graph() + if graph.has_node(source_node_id): + return list(graph.edges(source_node_id)) + return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - self._graph.add_node(node_id, **node_data) + with self._storage_lock: + graph = self._get_graph() + graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - self._graph.add_edge(source_node_id, target_node_id, **edge_data) + with self._storage_lock: + graph = self._get_graph() + graph.add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: - if self._graph.has_node(node_id): - self._graph.remove_node(node_id) - logger.info(f"Node {node_id} deleted from the graph.") - else: - logger.warning(f"Node {node_id} not found in the graph for deletion.") + with self._storage_lock: + graph = self._get_graph() + if graph.has_node(node_id): + graph.remove_node(node_id) + logger.debug(f"Node {node_id} deleted from the 
graph.") + else: + logger.warning(f"Node {node_id} not found in the graph for deletion.") async def embed_nodes( self, algorithm: str @@ -175,14 +173,15 @@ class NetworkXStorage(BaseGraphStorage): raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() - # @TODO: NOT USED + # TODO: NOT USED async def _node2vec_embed(self): - embeddings, nodes = embed.node2vec_embed( - self._graph, - **self.global_config["node2vec_params"], - ) - - nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] + with self._storage_lock: + graph = self._get_graph() + embeddings, nodes = embed.node2vec_embed( + graph, + **self.global_config["node2vec_params"], + ) + nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids def remove_nodes(self, nodes: list[str]): @@ -191,9 +190,11 @@ class NetworkXStorage(BaseGraphStorage): Args: nodes: List of node IDs to be deleted """ - for node in nodes: - if self._graph.has_node(node): - self._graph.remove_node(node) + with self._storage_lock: + graph = self._get_graph() + for node in nodes: + if graph.has_node(node): + graph.remove_node(node) def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges @@ -201,9 +202,11 @@ class NetworkXStorage(BaseGraphStorage): Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ - for source, target in edges: - if self._graph.has_edge(source, target): - self._graph.remove_edge(source, target) + with self._storage_lock: + graph = self._get_graph() + for source, target in edges: + if graph.has_edge(source, target): + graph.remove_edge(source, target) async def get_all_labels(self) -> list[str]: """ @@ -211,9 +214,11 @@ class NetworkXStorage(BaseGraphStorage): Returns: [label1, label2, ...] 
# Alphabetically sorted label list """ - labels = set() - for node in self._graph.nodes(): - labels.add(str(node)) # Add node id as a label + with self._storage_lock: + graph = self._get_graph() + labels = set() + for node in graph.nodes(): + labels.add(str(node)) # Add node id as a label # Return sorted list return sorted(list(labels)) @@ -235,87 +240,86 @@ class NetworkXStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - # Handle special case for "*" label - if node_label == "*": - # For "*", return the entire graph including all nodes and edges - subgraph = ( - self._graph.copy() - ) # Create a copy to avoid modifying the original graph - else: - # Find nodes with matching node id (partial match) - nodes_to_explore = [] - for n, attr in self._graph.nodes(data=True): - if node_label in str(n): # Use partial matching - nodes_to_explore.append(n) + with self._storage_lock: + graph = self._get_graph() + + # Handle special case for "*" label + if node_label == "*": + # For "*", return the entire graph including all nodes and edges + subgraph = graph.copy() # Create a copy to avoid modifying the original graph + else: + # Find nodes with matching node id (partial match) + nodes_to_explore = [] + for n, attr in graph.nodes(data=True): + if node_label in str(n): # Use partial matching + nodes_to_explore.append(n) - if not nodes_to_explore: - logger.warning(f"No nodes found with label {node_label}") - return result + if not nodes_to_explore: + logger.warning(f"No nodes found with label {node_label}") + return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph + subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) - # Check if number of nodes exceeds max_graph_nodes - max_graph_nodes = 500 - if len(subgraph.nodes()) > max_graph_nodes: - origin_nodes = len(subgraph.nodes()) - node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ - :max_graph_nodes - ] - top_node_ids = [node[0] for node in top_nodes] - # Create new subgraph with only top nodes - subgraph = subgraph.subgraph(top_node_ids) - logger.info( - f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" - ) - - # Add nodes to result - for node in subgraph.nodes(): - if str(node) in seen_nodes: - continue - - node_data = dict(subgraph.nodes[node]) - # Get entity_type as labels - labels = [] - if "entity_type" in node_data: - if isinstance(node_data["entity_type"], list): - labels.extend(node_data["entity_type"]) - else: - labels.append(node_data["entity_type"]) - - # Create node with properties - node_properties = {k: v for k, v in node_data.items()} - - result.nodes.append( - KnowledgeGraphNode( - id=str(node), labels=[str(node)], properties=node_properties + # Check if number of nodes exceeds max_graph_nodes + max_graph_nodes = 500 + if len(subgraph.nodes()) > max_graph_nodes: + origin_nodes = len(subgraph.nodes()) + node_degrees = dict(subgraph.degree()) + top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ + :max_graph_nodes + ] + top_node_ids = [node[0] for node in top_nodes] + # Create new subgraph with only top nodes + subgraph = subgraph.subgraph(top_node_ids) + logger.info( + f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" ) - ) - seen_nodes.add(str(node)) - # Add edges to result - for edge in subgraph.edges(): - source, target = edge - edge_id 
= f"{source}-{target}" - if edge_id in seen_edges: - continue + # Add nodes to result + for node in subgraph.nodes(): + if str(node) in seen_nodes: + continue - edge_data = dict(subgraph.edges[edge]) + node_data = dict(subgraph.nodes[node]) + # Get entity_type as labels + labels = [] + if "entity_type" in node_data: + if isinstance(node_data["entity_type"], list): + labels.extend(node_data["entity_type"]) + else: + labels.append(node_data["entity_type"]) - # Create edge with complete information - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source), - target=str(target), - properties=edge_data, + # Create node with properties + node_properties = {k: v for k, v in node_data.items()} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node), labels=[str(node)], properties=node_properties + ) ) - ) - seen_edges.add(edge_id) + seen_nodes.add(str(node)) - # logger.info(result.edges) + # Add edges to result + for edge in subgraph.edges(): + source, target = edge + edge_id = f"{source}-{target}" + if edge_id in seen_edges: + continue + + edge_data = dict(subgraph.edges[edge]) + + # Create edge with complete information + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source), + target=str(target), + properties=edge_data, + ) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py new file mode 100644 index 00000000..9de3bb79 --- /dev/null +++ b/lightrag/kg/shared_storage.py @@ -0,0 +1,94 @@ +from multiprocessing.synchronize import Lock as ProcessLock +from threading import Lock as ThreadLock +from multiprocessing import Manager +from typing import Any, Dict, Optional, Union + +# ๅฎšไน‰็ฑปๅž‹ๅ˜้‡ +LockType = Union[ProcessLock, ThreadLock] + +# ๅ…จๅฑ€ๅ˜้‡ +_shared_data: Optional[Dict[str, Any]] = None +_namespace_objects: Optional[Dict[str, Any]] = None +_global_lock: Optional[LockType] = None +is_multiprocess = False +manager = None + +def initialize_manager(): + """Initialize manager, only for multiple processes where workers > 1""" + global manager + if manager is None: + manager = Manager() + +def _get_global_lock() -> LockType: + global _global_lock, is_multiprocess + + if _global_lock is None: + if is_multiprocess: + _global_lock = manager.Lock() + else: + _global_lock = ThreadLock() + + return _global_lock + +def get_storage_lock() -> LockType: + """return storage lock for data consistency""" + return _get_global_lock() + +def get_scan_lock() -> LockType: + """return scan_progress lock for data consistency""" + return get_storage_lock() + +def get_shared_data() -> Dict[str, Any]: + """ + return shared data for all storage types + create mult-process save share data only if need for better performance + """ + global _shared_data, is_multiprocess + + if _shared_data is None: + lock = _get_global_lock() + with lock: + if _shared_data is None: + if is_multiprocess: + _shared_data = manager.dict() + else: + _shared_data = {} + + return _shared_data + +def get_namespace_object(namespace: str) -> Any: + """Get an object for specific namespace""" + global _namespace_objects, is_multiprocess + + if _namespace_objects is None: + lock = _get_global_lock() + with lock: + if _namespace_objects is None: + _namespace_objects = {} + + if namespace not in _namespace_objects: + lock = _get_global_lock() + with lock: + if namespace not in _namespace_objects: + 
if is_multiprocess: + _namespace_objects[namespace] = manager.Value('O', None) + else: + _namespace_objects[namespace] = None + + return _namespace_objects[namespace] + +def get_namespace_data(namespace: str) -> Dict[str, Any]: + """get storage space for specific storage type(namespace)""" + shared_data = get_shared_data() + lock = _get_global_lock() + + if namespace not in shared_data: + with lock: + if namespace not in shared_data: + shared_data[namespace] = {} + + return shared_data[namespace] + +def get_scan_progress() -> Dict[str, Any]: + """get storage space for document scanning progress data""" + return get_namespace_data('scan_progress') diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index c115b33a..d7da6017 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -266,13 +266,7 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) - def __post_init__(self): - # Initialize manager if needed - from lightrag.api.utils_api import manager, initialize_manager - if manager is None: - initialize_manager() - logger.info("Initialized manager for single process mode") - + def __post_init__(self): os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") From 15a6a9cf7c4be8eecf33d9c4dae9ac68ea0e4302 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 12:23:35 +0800 Subject: [PATCH 11/77] fix: log filtering void when uvicorn wokers is greater than 1 - Centralize logging setup - Fix logger propagation issues --- lightrag/api/lightrag_server.py | 86 +++++++++++++++++++++------------ lightrag/utils.py | 29 ++++++----- 2 files changed, 71 insertions(+), 44 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 65227e97..56f55833 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -30,7 +30,6 @@ from lightrag import LightRAG from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc -from lightrag.utils import logger from .routers.document_routes import ( DocumentManager, create_document_routes, @@ -40,36 +39,40 @@ from .routers.query_routes import create_query_routes from .routers.graph_routes import create_graph_routes from .routers.ollama_api import OllamaAPI +from lightrag.utils import logger as utils_logger + # Load environment variables try: load_dotenv(override=True) except Exception as e: - logger.warning(f"Failed to load .env file: {e}") + utils_logger.warning(f"Failed to load .env file: {e}") # Initialize config parser config = configparser.ConfigParser() config.read("config.ini") -class AccessLogFilter(logging.Filter): +class LightragPathFilter(logging.Filter): + """Filter for lightrag logger to filter out frequent path access logs""" def __init__(self): super().__init__() # Define paths to be filtered self.filtered_paths = ["/documents", "/health", "/webui/"] - + def filter(self, record): try: + # Check if record has the required attributes for an access log if not hasattr(record, "args") or not isinstance(record.args, tuple): return True if len(record.args) < 5: return True + # Extract method, path and status from the record args method = record.args[1] path = record.args[2] status = record.args[4] - # print(f"Debug - Method: {method}, Path: {path}, Status: {status}") - # print(f"Debug - Filtered paths: {self.filtered_paths}") + # Filter 
out successful GET requests to filtered paths if ( method == "GET" and (status == 200 or status == 304) @@ -78,17 +81,23 @@ class AccessLogFilter(logging.Filter): return False return True - except Exception: + # In case of any error, let the message through return True def create_app(args): # Initialize verbose debug setting - from lightrag.utils import set_verbose_debug - + # Can not use the logger at the top of this module when workers > 1 + from lightrag.utils import set_verbose_debug, logger + # Setup logging + logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) + # Display splash screen + from lightrag.kg.shared_storage import is_multiprocess + logger.info(f"==== Multi-processor mode: {is_multiprocess} ====") + # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", @@ -120,11 +129,6 @@ def create_app(args): if not os.path.exists(args.ssl_keyfile): raise Exception(f"SSL key file not found: {args.ssl_keyfile}") - # Setup logging - logging.basicConfig( - format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) - ) - # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key @@ -406,6 +410,9 @@ def create_app(args): def get_application(): """Factory function for creating the FastAPI application""" + # Configure logging for this worker process + configure_logging() + # Get args from environment variable args_json = os.environ.get('LIGHTRAG_ARGS') if not args_json: @@ -414,24 +421,22 @@ def get_application(): import types args = types.SimpleNamespace(**json.loads(args_json)) + # if args.workers > 1: + # from lightrag.kg.shared_storage import initialize_manager + # initialize_manager() + return create_app(args) -def main(): - from multiprocessing import freeze_support - freeze_support() +def configure_logging(): + """Configure logging for both uvicorn and lightrag""" + # Reset any existing handlers to ensure clean configuration + for logger_name in ["uvicorn.access", "lightrag"]: + logger = logging.getLogger(logger_name) + logger.handlers = [] + logger.filters = [] - args = parse_args() - # Save args to environment variable for child processes - os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) - - if args.workers > 1: - from lightrag.kg.shared_storage import initialize_manager - initialize_manager() - import lightrag.kg.shared_storage as shared_storage - shared_storage.is_multiprocess = True - - # Configure uvicorn logging + # Configure basic logging logging.config.dictConfig({ "version": 1, "disable_existing_loggers": False, @@ -452,13 +457,32 @@ def main(): "handlers": ["default"], "level": "INFO", "propagate": False, + "filters": ["path_filter"], + }, + "lightrag": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + }, + "filters": { + "path_filter": { + "()": "lightrag.api.lightrag_server.LightragPathFilter", }, }, }) - # Add filter to uvicorn access logger - uvicorn_access_logger = logging.getLogger("uvicorn.access") - uvicorn_access_logger.addFilter(AccessLogFilter()) +def main(): + from multiprocessing import freeze_support + freeze_support() + + args = parse_args() + # Save args to environment variable for child processes + os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) + + # Configure logging before starting uvicorn + configure_logging() display_splash_screen(args) diff --git a/lightrag/utils.py b/lightrag/utils.py index e7217def..bc78e2cb 100644 --- a/lightrag/utils.py +++ 
b/lightrag/utils.py @@ -55,22 +55,13 @@ def set_verbose_debug(enabled: bool): global VERBOSE_DEBUG VERBOSE_DEBUG = enabled - -class UnlimitedSemaphore: - """A context manager that allows unlimited access.""" - - async def __aenter__(self): - pass - - async def __aexit__(self, exc_type, exc, tb): - pass - - -ENCODER = None - statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} +# Initialize logger logger = logging.getLogger("lightrag") +logger.propagate = False # prevent log message send to root loggger +# Let the main application configure the handlers +logger.setLevel(logging.INFO) # Set httpx logging level to WARNING logging.getLogger("httpx").setLevel(logging.WARNING) @@ -97,6 +88,18 @@ def set_logger(log_file: str, level: int = logging.DEBUG): logger.addHandler(file_handler) +class UnlimitedSemaphore: + """A context manager that allows unlimited access.""" + + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc, tb): + pass + + +ENCODER = None + @dataclass class EmbeddingFunc: embedding_dim: int From 2c019dbc7b46991a9817dff5cfebafb335237699 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 12:28:49 +0800 Subject: [PATCH 12/77] Refactor storage initialization to avoid redundant intitial data loads across processes, show init logs to first load only --- lightrag/kg/faiss_impl.py | 14 +++++++------- lightrag/kg/json_doc_status_impl.py | 5 +++-- lightrag/kg/json_kv_impl.py | 4 ++-- lightrag/kg/nano_vector_db_impl.py | 6 +++--- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 8c9c52c4..3e59d171 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -57,16 +57,16 @@ class FaissVectorDBStorage(BaseVectorStorage): # If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP. self._index.value = faiss.IndexFlatIP(self._dim) + # Keep a local store for metadata, IDs, etc. + # Maps โ†’ metadata (including your original ID). + self._id_to_meta.update({}) + # Attempt to load an existing index + metadata from disk + self._load_faiss_index() else: if self._index is None: self._index = faiss.IndexFlatIP(self._dim) - - # Keep a local store for metadata, IDs, etc. - # Maps โ†’ metadata (including your original ID). 
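# A simplified sketch, assuming the shared-storage primitives this series adds:
# the goal of loading initial data only in the first worker is a double-checked
# guard on a shared flag dict, which PATCH 16 later formalizes as
# try_initialize_namespace(). The names below mirror that later code and are
# illustrative only; nothing here is added to faiss_impl itself.
_init_flags: dict = {}  # namespace -> initialized marker (a manager dict when workers > 1)

def first_loader_wins(namespace: str, lock) -> bool:
    """Return True only for the single caller that should load data from disk."""
    with lock:
        if namespace not in _init_flags:
            _init_flags[namespace] = True  # set atomically while holding the shared lock
            return True
    return False  # another worker already initialized this namespace

# Usage sketch inside a storage __post_init__:
#     if first_loader_wins(self.namespace, get_storage_lock()):
#         self._load_faiss_index()  # only the first worker pays the disk load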
- self._id_to_meta.update({}) - - # Attempt to load an existing index + metadata from disk - self._load_faiss_index() + self._id_to_meta.update({}) + self._load_faiss_index() async def upsert(self, data: dict[str, dict[str, Any]]) -> None: diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 50451f95..58ee3666 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -26,8 +26,9 @@ class JsonDocStatusStorage(DocStatusStorage): self._storage_lock = get_storage_lock() self._data = get_namespace_data(self.namespace) with self._storage_lock: - self._data.update(load_json(self._file_name) or {}) - logger.info(f"Loaded document status storage with {len(self._data)} records") + if not self._data: + self._data.update(load_json(self._file_name) or {}) + logger.info(f"Loaded document status storage with {len(self._data)} records") async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index a53ac8f0..ee5d8a07 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -18,13 +18,13 @@ from .shared_storage import get_namespace_data, get_storage_lock class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() self._data = get_namespace_data(self.namespace) with self._storage_lock: if not self._data: - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._data: dict[str, Any] = load_json(self._file_name) or {} - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") async def index_done_callback(self) -> None: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 07f8d367..d1682c7a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -47,14 +47,14 @@ class NanoVectorDBStorage(BaseVectorStorage): if self._client.value is None: self._client.value = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name - ) + ) + logger.info(f"Initialized vector DB client for namespace {self.namespace}") else: if self._client is None: self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name ) - - logger.info(f"Initialized vector DB client for namespace {self.namespace}") + logger.info(f"Initialized vector DB client for namespace {self.namespace}") def _get_client(self): """Get the appropriate client instance based on multiprocess mode""" From 41f5d208a9b1204afb9d3bc580a81d633310f9b3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 13:32:15 +0800 Subject: [PATCH 13/77] fix: shared data intitialization failed for multi-worker --- lightrag/api/lightrag_server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 56f55833..07108c52 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -94,7 +94,6 @@ def create_app(args): logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) - # Display splash screen from lightrag.kg.shared_storage import is_multiprocess logger.info(f"==== 
Multi-processor mode: {is_multiprocess} ====") @@ -421,9 +420,9 @@ def get_application(): import types args = types.SimpleNamespace(**json.loads(args_json)) - # if args.workers > 1: - # from lightrag.kg.shared_storage import initialize_manager - # initialize_manager() + if args.workers > 1: + from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data() return create_app(args) @@ -486,6 +485,7 @@ def main(): display_splash_screen(args) + uvicorn_config = { "app": "lightrag.api.lightrag_server:get_application", "factory": True, From 145bacc773b049390892d76c19687afbf0c529c9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 17:42:30 +0800 Subject: [PATCH 14/77] Add empty graph creation logging in NetworkXStorage --- lightrag/kg/networkx_impl.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index df07499b..74a6ee28 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -80,16 +80,22 @@ class NetworkXStorage(BaseGraphStorage): if self._graph.value is None: preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) self._graph.value = preloaded_graph or nx.Graph() - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) + if preloaded_graph: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + else: + logger.info("Created new empty graph") else: if self._graph is None: preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) self._graph = preloaded_graph or nx.Graph() - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) + if preloaded_graph: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + else: + logger.info("Created new empty graph") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, From 4eb069d1d67a76cc3f43881b4ac0ca8c922375f4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 17:42:49 +0800 Subject: [PATCH 15/77] Initialize scan_progress with default values if not already set --- lightrag/api/routers/document_routes.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index c084023d..50bc39df 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -379,8 +379,18 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): scan_progress = get_scan_progress() scan_lock = get_scan_lock() + # Initialize scan_progress if not already initialized + if not scan_progress: + scan_progress.update({ + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + }) + with scan_lock: - if scan_progress["is_scanning"]: + if scan_progress.get("is_scanning", False): ASCIIColors.info( "Skip document scanning(another scanning is active)" ) From 7d12715f098c8b8365a97976452eeab8da9225b1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 18:11:02 +0800 Subject: [PATCH 16/77] Refactor shared storage to safely handle multi-process initialization and data sharing 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add namespace initialization check โ€ข Use atomic operations for shared data --- lightrag/kg/json_doc_status_impl.py | 20 +++-- lightrag/kg/shared_storage.py | 128 +++++++++++++++++----------- 2 files changed, 93 insertions(+), 55 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 58ee3666..2a85c68a 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -12,7 +12,11 @@ from lightrag.utils import ( logger, write_json, ) -from .shared_storage import get_namespace_data, get_storage_lock +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + try_initialize_namespace, +) @final @@ -24,11 +28,17 @@ class JsonDocStatusStorage(DocStatusStorage): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + + # check need_init must before get_namespace_data + need_init = try_initialize_namespace(self.namespace) self._data = get_namespace_data(self.namespace) - with self._storage_lock: - if not self._data: - self._data.update(load_json(self._file_name) or {}) - logger.info(f"Loaded document status storage with {len(self._data)} records") + if need_init: + loaded_data = load_json(self._file_name) or {} + with self._storage_lock: + self._data.update(loaded_data) + logger.info( + f"Loaded document status storage with {len(loaded_data)} records" + ) async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 9de3bb79..27aca9d0 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,30 +1,74 @@ +import os from multiprocessing.synchronize import Lock as ProcessLock from threading import Lock as ThreadLock from multiprocessing import Manager from typing import Any, Dict, Optional, Union +from lightrag.utils import logger -# ๅฎšไน‰็ฑปๅž‹ๅ˜้‡ LockType = Union[ProcessLock, ThreadLock] -# ๅ…จๅฑ€ๅ˜้‡ -_shared_data: Optional[Dict[str, Any]] = None -_namespace_objects: Optional[Dict[str, Any]] = None -_global_lock: Optional[LockType] = None is_multiprocess = False -manager = None -def initialize_manager(): - """Initialize manager, only for multiple processes where workers > 1""" - global manager - if manager is None: - manager = Manager() +_manager = None +_global_lock: Optional[LockType] = None + +# shared data for storage across processes +_shared_dicts: Optional[Dict[str, Any]] = {} +_share_objects: Optional[Dict[str, Any]] = {} +_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized + +def initialize_share_data(): + """Initialize shared data, only called if multiple processes where workers > 1""" + global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess + is_multiprocess = True + + logger.info(f"Process {os.getpid()} initializing shared storage") + + # Initialize manager + if _manager is None: + _manager = Manager() + logger.info(f"Process {os.getpid()} created manager") + + # Create shared dictionaries with manager + _shared_dicts = _manager.dict() + _share_objects = _manager.dict() + _init_flags = _manager.dict() # ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธๅญ˜ๅ‚จๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— + logger.info(f"Process {os.getpid()} created shared dictionaries") + +def try_initialize_namespace(namespace: 
str) -> bool: + """ + ๅฐ่ฏ•ๅˆๅง‹ๅŒ–ๅ‘ฝๅ็ฉบ้—ดใ€‚่ฟ”ๅ›žTrue่กจ็คบๅฝ“ๅ‰่ฟ›็จ‹่Žทๅพ—ไบ†ๅˆๅง‹ๅŒ–ๆƒ้™ใ€‚ + ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธ็š„ๅŽŸๅญๆ“ไฝœ็กฎไฟๅชๆœ‰ไธ€ไธช่ฟ›็จ‹่ƒฝๆˆๅŠŸๅˆๅง‹ๅŒ–ใ€‚ + """ + global _init_flags, _manager + + if is_multiprocess: + if _init_flags is None: + raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.") + else: + if _init_flags is None: + _init_flags = {} + + logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") + + # ไฝฟ็”จๅ…จๅฑ€้”ไฟๆŠคๅ…ฑไบซๅญ—ๅ…ธ็š„่ฎฟ้—ฎ + with _get_global_lock(): + # ๆฃ€ๆŸฅๆ˜ฏๅฆๅทฒ็ปๅˆๅง‹ๅŒ– + if namespace not in _init_flags: + # ่ฎพ็ฝฎๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— + _init_flags[namespace] = True + logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}") + return True + + logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized") + return False def _get_global_lock() -> LockType: - global _global_lock, is_multiprocess + global _global_lock, is_multiprocess, _manager if _global_lock is None: if is_multiprocess: - _global_lock = manager.Lock() + _global_lock = _manager.Lock() # Use manager for lock else: _global_lock = ThreadLock() @@ -38,56 +82,40 @@ def get_scan_lock() -> LockType: """return scan_progress lock for data consistency""" return get_storage_lock() -def get_shared_data() -> Dict[str, Any]: - """ - return shared data for all storage types - create mult-process save share data only if need for better performance - """ - global _shared_data, is_multiprocess - - if _shared_data is None: - lock = _get_global_lock() - with lock: - if _shared_data is None: - if is_multiprocess: - _shared_data = manager.dict() - else: - _shared_data = {} - - return _shared_data - def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" - global _namespace_objects, is_multiprocess - - if _namespace_objects is None: + global _share_objects, is_multiprocess, _manager + + if is_multiprocess and not _manager: + raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + + if namespace not in _share_objects: lock = _get_global_lock() with lock: - if _namespace_objects is None: - _namespace_objects = {} - - if namespace not in _namespace_objects: - lock = _get_global_lock() - with lock: - if namespace not in _namespace_objects: + if namespace not in _share_objects: if is_multiprocess: - _namespace_objects[namespace] = manager.Value('O', None) + _share_objects[namespace] = _manager.Value('O', None) else: - _namespace_objects[namespace] = None + _share_objects[namespace] = None - return _namespace_objects[namespace] + return _share_objects[namespace] + +# ็งป้™คไธๅ†ไฝฟ็”จ็š„ๅ‡ฝๆ•ฐ def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" - shared_data = get_shared_data() - lock = _get_global_lock() + global _shared_dicts, is_multiprocess, _manager - if namespace not in shared_data: + if is_multiprocess and not _manager: + raise RuntimeError("Multiprocess mode detected but shared storage not initialized. 
Call initialize_share_data() first.") + + if namespace not in _shared_dicts: + lock = _get_global_lock() with lock: - if namespace not in shared_data: - shared_data[namespace] = {} + if namespace not in _shared_dicts: + _shared_dicts[namespace] = {} - return shared_data[namespace] + return _shared_dicts[namespace] def get_scan_progress() -> Dict[str, Any]: """get storage space for document scanning progress data""" From 7436c06f6cb6d92794618d81a3d44bd85952c7ff Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 18:11:16 +0800 Subject: [PATCH 17/77] Fix linting --- .env.example | 2 +- lightrag/api/lightrag_server.py | 90 +++++++++++---------- lightrag/api/routers/document_routes.py | 100 +++++++++++++----------- lightrag/api/utils_api.py | 1 - lightrag/kg/faiss_impl.py | 36 ++++++--- lightrag/kg/json_kv_impl.py | 1 - lightrag/kg/nano_vector_db_impl.py | 28 ++++--- lightrag/kg/networkx_impl.py | 28 ++++--- lightrag/kg/shared_storage.py | 59 +++++++++----- lightrag/lightrag.py | 2 +- lightrag/utils.py | 2 + 11 files changed, 205 insertions(+), 144 deletions(-) diff --git a/.env.example b/.env.example index 0f8e6c31..8a14cdb3 100644 --- a/.env.example +++ b/.env.example @@ -141,4 +141,4 @@ QDRANT_URL=http://localhost:16333 # QDRANT_API_KEY=your-api-key ### Redis -REDIS_URI=redis://localhost:6379 \ No newline at end of file +REDIS_URI=redis://localhost:6379 diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 07108c52..270bbb24 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -54,11 +54,12 @@ config.read("config.ini") class LightragPathFilter(logging.Filter): """Filter for lightrag logger to filter out frequent path access logs""" + def __init__(self): super().__init__() # Define paths to be filtered self.filtered_paths = ["/documents", "/health", "/webui/"] - + def filter(self, record): try: # Check if record has the required attributes for an access log @@ -90,11 +91,13 @@ def create_app(args): # Initialize verbose debug setting # Can not use the logger at the top of this module when workers > 1 from lightrag.utils import set_verbose_debug, logger + # Setup logging logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) from lightrag.kg.shared_storage import is_multiprocess + logger.info(f"==== Multi-processor mode: {is_multiprocess} ====") # Verify that bindings are correctly setup @@ -147,9 +150,7 @@ def create_app(args): # Auto scan documents if enabled if args.auto_scan_at_startup: # Create background task - task = asyncio.create_task( - run_scanning_process(rag, doc_manager) - ) + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) app.state.background_tasks.add(task) task.add_done_callback(app.state.background_tasks.discard) @@ -411,17 +412,19 @@ def get_application(): """Factory function for creating the FastAPI application""" # Configure logging for this worker process configure_logging() - + # Get args from environment variable - args_json = os.environ.get('LIGHTRAG_ARGS') + args_json = os.environ.get("LIGHTRAG_ARGS") if not args_json: args = parse_args() # Fallback to parsing args if env var not set else: import types + args = types.SimpleNamespace(**json.loads(args_json)) - + if args.workers > 1: from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data() return create_app(args) @@ -434,58 +437,61 @@ def configure_logging(): logger = logging.getLogger(logger_name) logger.handlers = [] logger.filters = [] - + # Configure basic logging - 
logging.config.dictConfig({ - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(levelname)s: %(message)s", + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", + }, }, - }, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.StreamHandler", - "stream": "ext://sys.stderr", + "handlers": { + "default": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, }, - }, - "loggers": { - "uvicorn.access": { - "handlers": ["default"], - "level": "INFO", - "propagate": False, - "filters": ["path_filter"], + "loggers": { + "uvicorn.access": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + "lightrag": { + "handlers": ["default"], + "level": "INFO", + "propagate": False, + "filters": ["path_filter"], + }, }, - "lightrag": { - "handlers": ["default"], - "level": "INFO", - "propagate": False, - "filters": ["path_filter"], + "filters": { + "path_filter": { + "()": "lightrag.api.lightrag_server.LightragPathFilter", + }, }, - }, - "filters": { - "path_filter": { - "()": "lightrag.api.lightrag_server.LightragPathFilter", - }, - }, - }) + } + ) + def main(): from multiprocessing import freeze_support + freeze_support() - + args = parse_args() # Save args to environment variable for child processes - os.environ['LIGHTRAG_ARGS'] = json.dumps(vars(args)) + os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) # Configure logging before starting uvicorn configure_logging() display_splash_screen(args) - uvicorn_config = { "app": "lightrag.api.lightrag_server:get_application", "factory": True, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 50bc39df..1f591750 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -375,62 +375,70 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): - """Background task to scan and index documents""" + """Background task to scan and index documents""" scan_progress = get_scan_progress() scan_lock = get_scan_lock() - + # Initialize scan_progress if not already initialized if not scan_progress: - scan_progress.update({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) - + scan_progress.update( + { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) + with scan_lock: if scan_progress.get("is_scanning", False): - ASCIIColors.info( - "Skip document scanning(another scanning is active)" - ) + ASCIIColors.info("Skip document scanning(another scanning is active)") return - scan_progress.update({ - "is_scanning": True, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) + scan_progress.update( + { + "is_scanning": True, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - scan_progress.update({ - "current_file": "", - "total_files": total_files, - "indexed_count": 0, - "progress": 0, - }) + scan_progress.update( + { + "current_file": "", + "total_files": total_files, + "indexed_count": 0, + "progress": 0, + } + ) logging.info(f"Found 
{total_files} new files to index.") for idx, file_path in enumerate(new_files): try: progress = (idx / total_files * 100) if total_files > 0 else 0 - scan_progress.update({ - "current_file": os.path.basename(file_path), - "indexed_count": idx, - "progress": progress, - }) - + scan_progress.update( + { + "current_file": os.path.basename(file_path), + "indexed_count": idx, + "progress": progress, + } + ) + await pipeline_index_file(rag, file_path) - + progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0 - scan_progress.update({ - "current_file": os.path.basename(file_path), - "indexed_count": idx + 1, - "progress": progress, - }) + scan_progress.update( + { + "current_file": os.path.basename(file_path), + "indexed_count": idx + 1, + "progress": progress, + } + ) except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") @@ -438,13 +446,15 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): except Exception as e: logging.error(f"Error during scanning process: {str(e)}") finally: - scan_progress.update({ - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - }) + scan_progress.update( + { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) def create_document_routes( diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 6b501e64..c494101c 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -433,7 +433,6 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.white(" โ””โ”€ Document Status Storage: ", end="") ASCIIColors.yellow(f"{args.doc_status_storage}") - # Server Status ASCIIColors.green("\nโœจ Server starting up...\n") diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 3e59d171..a9d058f4 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -8,14 +8,19 @@ import numpy as np from dataclasses import dataclass import pipmaster as pm -from lightrag.utils import logger,compute_mdhash_id +from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseVectorStorage -from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + get_namespace_object, + is_multiprocess, +) if not pm.is_installed("faiss"): pm.install("faiss") -import faiss # type: ignore +import faiss # type: ignore @final @@ -46,10 +51,10 @@ class FaissVectorDBStorage(BaseVectorStorage): # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim self._storage_lock = get_storage_lock() - - self._index = get_namespace_object('faiss_indices') - self._id_to_meta = get_namespace_data('faiss_meta') - + + self._index = get_namespace_object("faiss_indices") + self._id_to_meta = get_namespace_data("faiss_meta") + with self._storage_lock: if is_multiprocess: if self._index.value is None: @@ -68,7 +73,6 @@ class FaissVectorDBStorage(BaseVectorStorage): self._id_to_meta.update({}) self._load_faiss_index() - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. 
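Editor's aside: the IndexFlatIP chosen above only behaves as cosine similarity when every vector is L2-normalized before it is added and before it is searched, which is what the upsert and query paths do with the embeddings. A standalone toy sketch of that invariant (dimensions and data are made up; nothing here is LightRAG code):

import numpy as np
import faiss

dim = 8
index = faiss.IndexFlatIP(dim)  # inner product equals cosine once vectors are unit length

vecs = np.random.rand(4, dim).astype(np.float32)
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # normalize rows before adding
index.add(vecs)

query = vecs[:1]                       # already unit length
scores, ids = index.search(query, 3)   # scores are cosine similarities
print(ids[0], scores[0])               # best hit is the query itself, score ~= 1.0
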
@@ -168,7 +172,9 @@ class FaissVectorDBStorage(BaseVectorStorage): # Perform the similarity search with self._storage_lock: - distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k) + distances, indices = ( + self._index.value if is_multiprocess else self._index + ).search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -232,7 +238,10 @@ class FaissVectorDBStorage(BaseVectorStorage): with self._storage_lock: relations = [] for fid, meta in self._id_to_meta.items(): - if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: + if ( + meta.get("src_id") == entity_name + or meta.get("tgt_id") == entity_name + ): relations.append(fid) logger.debug(f"Found {len(relations)} relations for {entity_name}") @@ -292,7 +301,10 @@ class FaissVectorDBStorage(BaseVectorStorage): Save the current Faiss index + metadata to disk so it can persist across runs. """ with self._storage_lock: - faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file) + faiss.write_index( + self._index.value if is_multiprocess else self._index, + self._faiss_index_file, + ) # Save metadata dict to JSON. Convert all keys to strings for JSON storage. # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } @@ -320,7 +332,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._index.value = loaded_index else: self._index = loaded_index - + # Load metadata with open(self._meta_file, "r", encoding="utf-8") as f: stored_dict = json.load(f) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index ee5d8a07..4c80854a 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -26,7 +26,6 @@ class JsonKVStorage(BaseKVStorage): self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - async def index_done_callback(self) -> None: # ๆ–‡ไปถๅ†™ๅ…ฅ้œ€่ฆๅŠ ้”๏ผŒ้˜ฒๆญขๅคšไธช่ฟ›็จ‹ๅŒๆ—ถๅ†™ๅ…ฅๅฏผ่‡ดๆ–‡ไปถๆŸๅ with self._storage_lock: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index d1682c7a..7707a0f0 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -25,7 +25,7 @@ class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Initialize lock only for file operations self._storage_lock = get_storage_lock() - + # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -39,22 +39,28 @@ class NanoVectorDBStorage(BaseVectorStorage): self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] - + self._client = get_namespace_object(self.namespace) - + with self._storage_lock: if is_multiprocess: if self._client.value is None: self._client.value = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + logger.info( + f"Initialized vector DB client for namespace {self.namespace}" ) - logger.info(f"Initialized vector DB client for namespace {self.namespace}") else: if self._client is None: self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + logger.info( + 
f"Initialized vector DB client for namespace {self.namespace}" ) - logger.info(f"Initialized vector DB client for namespace {self.namespace}") def _get_client(self): """Get the appropriate client instance based on multiprocess mode""" @@ -104,7 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage): # Execute embedding outside of lock to avoid long lock times embedding = await self.embedding_func([query]) embedding = embedding[0] - + with self._storage_lock: client = self._get_client() results = client.query( @@ -150,7 +156,7 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.debug( f"Attempting to delete entity {entity_name} with ID {entity_id}" ) - + with self._storage_lock: client = self._get_client() # Check if the entity exists @@ -172,7 +178,9 @@ class NanoVectorDBStorage(BaseVectorStorage): for dp in storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name ] - logger.debug(f"Found {len(relations)} relations for entity {entity_name}") + logger.debug( + f"Found {len(relations)} relations for entity {entity_name}" + ) ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 74a6ee28..07bd9666 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -78,29 +78,33 @@ class NetworkXStorage(BaseGraphStorage): with self._storage_lock: if is_multiprocess: if self._graph.value is None: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + preloaded_graph = NetworkXStorage.load_nx_graph( + self._graphml_xml_file + ) self._graph.value = preloaded_graph or nx.Graph() if preloaded_graph: logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) else: logger.info("Created new empty graph") else: if self._graph is None: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + preloaded_graph = NetworkXStorage.load_nx_graph( + self._graphml_xml_file + ) self._graph = preloaded_graph or nx.Graph() if preloaded_graph: logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) else: logger.info("Created new empty graph") self._node_embed_algorithms = { - "node2vec": self._node2vec_embed, + "node2vec": self._node2vec_embed, } - + def _get_graph(self): """Get the appropriate graph instance based on multiprocess mode""" if is_multiprocess: @@ -248,11 +252,13 @@ class NetworkXStorage(BaseGraphStorage): with self._storage_lock: graph = self._get_graph() - + # Handle special case for "*" label if node_label == "*": # For "*", return the entire graph including all nodes and edges - subgraph = graph.copy() # Create a copy to avoid modifying the original graph + subgraph = ( + graph.copy() + ) # Create a copy to avoid modifying the original graph else: # Find nodes with matching node id (partial match) nodes_to_explore = [] @@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage): if len(subgraph.nodes()) > max_graph_nodes: origin_nodes = len(subgraph.nodes()) node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], 
reverse=True)[ - :max_graph_nodes - ] + top_nodes = sorted( + node_degrees.items(), key=lambda x: x[1], reverse=True + )[:max_graph_nodes] top_node_ids = [node[0] for node in top_nodes] # Create new subgraph with only top nodes subgraph = subgraph.subgraph(top_node_ids) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 27aca9d0..bd4c55fe 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -17,106 +17,125 @@ _shared_dicts: Optional[Dict[str, Any]] = {} _share_objects: Optional[Dict[str, Any]] = {} _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized + def initialize_share_data(): """Initialize shared data, only called if multiple processes where workers > 1""" global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess is_multiprocess = True - + logger.info(f"Process {os.getpid()} initializing shared storage") - + # Initialize manager if _manager is None: _manager = Manager() logger.info(f"Process {os.getpid()} created manager") - + # Create shared dictionaries with manager _shared_dicts = _manager.dict() _share_objects = _manager.dict() _init_flags = _manager.dict() # ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธๅญ˜ๅ‚จๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— logger.info(f"Process {os.getpid()} created shared dictionaries") + def try_initialize_namespace(namespace: str) -> bool: """ ๅฐ่ฏ•ๅˆๅง‹ๅŒ–ๅ‘ฝๅ็ฉบ้—ดใ€‚่ฟ”ๅ›žTrue่กจ็คบๅฝ“ๅ‰่ฟ›็จ‹่Žทๅพ—ไบ†ๅˆๅง‹ๅŒ–ๆƒ้™ใ€‚ ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธ็š„ๅŽŸๅญๆ“ไฝœ็กฎไฟๅชๆœ‰ไธ€ไธช่ฟ›็จ‹่ƒฝๆˆๅŠŸๅˆๅง‹ๅŒ–ใ€‚ """ global _init_flags, _manager - + if is_multiprocess: if _init_flags is None: - raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.") + raise RuntimeError( + "Shared storage not initialized. Call initialize_share_data() first." + ) else: if _init_flags is None: _init_flags = {} - + logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") - + # ไฝฟ็”จๅ…จๅฑ€้”ไฟๆŠคๅ…ฑไบซๅญ—ๅ…ธ็š„่ฎฟ้—ฎ with _get_global_lock(): # ๆฃ€ๆŸฅๆ˜ฏๅฆๅทฒ็ปๅˆๅง‹ๅŒ– if namespace not in _init_flags: # ่ฎพ็ฝฎๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— _init_flags[namespace] = True - logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}") + logger.info( + f"Process {os.getpid()} ready to initialize namespace {namespace}" + ) return True - - logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized") + + logger.info( + f"Process {os.getpid()} found namespace {namespace} already initialized" + ) return False + def _get_global_lock() -> LockType: global _global_lock, is_multiprocess, _manager - + if _global_lock is None: if is_multiprocess: _global_lock = _manager.Lock() # Use manager for lock else: _global_lock = ThreadLock() - + return _global_lock + def get_storage_lock() -> LockType: """return storage lock for data consistency""" return _get_global_lock() + def get_scan_lock() -> LockType: """return scan_progress lock for data consistency""" return get_storage_lock() + def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" global _share_objects, is_multiprocess, _manager - + if is_multiprocess and not _manager: - raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + raise RuntimeError( + "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." 
+ ) if namespace not in _share_objects: lock = _get_global_lock() with lock: if namespace not in _share_objects: if is_multiprocess: - _share_objects[namespace] = _manager.Value('O', None) + _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None - + return _share_objects[namespace] + # ็งป้™คไธๅ†ไฝฟ็”จ็š„ๅ‡ฝๆ•ฐ + def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" global _shared_dicts, is_multiprocess, _manager - + if is_multiprocess and not _manager: - raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + raise RuntimeError( + "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." + ) if namespace not in _shared_dicts: lock = _get_global_lock() with lock: if namespace not in _shared_dicts: _shared_dicts[namespace] = {} - + return _shared_dicts[namespace] + def get_scan_progress() -> Dict[str, Any]: """get storage space for document scanning progress data""" - return get_namespace_data('scan_progress') + return get_namespace_data("scan_progress") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d7da6017..46638243 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -266,7 +266,7 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) - def __post_init__(self): + def __post_init__(self): os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") diff --git a/lightrag/utils.py b/lightrag/utils.py index bc78e2cb..a6265048 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -55,6 +55,7 @@ def set_verbose_debug(enabled: bool): global VERBOSE_DEBUG VERBOSE_DEBUG = enabled + statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} # Initialize logger @@ -100,6 +101,7 @@ class UnlimitedSemaphore: ENCODER = None + @dataclass class EmbeddingFunc: embedding_dim: int From 7c237920b105643361940d21247339a1c1a3c765 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 08:48:33 +0800 Subject: [PATCH 18/77] Refactor shared storage to support both single and multi-process modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Initialize storage based on worker count โ€ข Remove redundant global variable checks โ€ข Add explicit mutex initialization โ€ข Centralize shared storage initialization โ€ข Fix process/thread lock selection logic --- lightrag/api/lightrag_server.py | 12 ++--- lightrag/kg/shared_storage.py | 87 ++++++++++++++------------------- lightrag/lightrag.py | 3 ++ 3 files changed, 43 insertions(+), 59 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 270bbb24..3af8887d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -96,10 +96,6 @@ def create_app(args): logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) - from lightrag.kg.shared_storage import is_multiprocess - - logger.info(f"==== Multi-processor mode: {is_multiprocess} ====") - # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", @@ -422,11 +418,6 @@ def get_application(): args = types.SimpleNamespace(**json.loads(args_json)) - if args.workers > 1: - from lightrag.kg.shared_storage import initialize_share_data - - 
initialize_share_data() - return create_app(args) @@ -492,6 +483,9 @@ def main(): display_splash_screen(args) + from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data(args.workers) + uvicorn_config = { "app": "lightrag.api.lightrag_server:get_application", "factory": True, diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index bd4c55fe..6b5c07f6 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -7,35 +7,50 @@ from lightrag.utils import logger LockType = Union[ProcessLock, ThreadLock] -is_multiprocess = False - _manager = None -_global_lock: Optional[LockType] = None +_initialized = None +_is_multiprocess = None +is_multiprocess = None # shared data for storage across processes -_shared_dicts: Optional[Dict[str, Any]] = {} -_share_objects: Optional[Dict[str, Any]] = {} +_shared_dicts: Optional[Dict[str, Any]] = None +_share_objects: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized +_global_lock: Optional[LockType] = None -def initialize_share_data(): - """Initialize shared data, only called if multiple processes where workers > 1""" - global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess - is_multiprocess = True - logger.info(f"Process {os.getpid()} initializing shared storage") +def initialize_share_data(workers: int = 1): + """Initialize storage data""" + global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + + if _initialized and _initialized.value: + is_multiprocess = _is_multiprocess.value + if _is_multiprocess.value: + logger.info(f"Process {os.getpid()} storage data already initialized!") + return - # Initialize manager - if _manager is None: - _manager = Manager() - logger.info(f"Process {os.getpid()} created manager") + _manager = Manager() + _initialized = _manager.Value("b", False) + _is_multiprocess = _manager.Value("b", False) - # Create shared dictionaries with manager - _shared_dicts = _manager.dict() - _share_objects = _manager.dict() - _init_flags = _manager.dict() # ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธๅญ˜ๅ‚จๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— - logger.info(f"Process {os.getpid()} created shared dictionaries") + if workers == 1: + _is_multiprocess.value = False + _global_lock = ThreadLock() + _shared_dicts = {} + _share_objects = {} + _init_flags = {} + logger.info(f"Process {os.getpid()} storage data created for Single Process") + else: + _is_multiprocess.value = True + _global_lock = _manager.Lock() + # Create shared dictionaries with manager + _shared_dicts = _manager.dict() + _share_objects = _manager.dict() + _init_flags = _manager.dict() # ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธๅญ˜ๅ‚จๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— + logger.info(f"Process {os.getpid()} storage data created for Multiple Process") + is_multiprocess = _is_multiprocess.value def try_initialize_namespace(namespace: str) -> bool: """ @@ -44,7 +59,7 @@ def try_initialize_namespace(namespace: str) -> bool: """ global _init_flags, _manager - if is_multiprocess: + if _is_multiprocess.value: if _init_flags is None: raise RuntimeError( "Shared storage not initialized. Call initialize_share_data() first." 
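Editor's aside: the first-caller-wins behaviour of try_initialize_namespace() is easier to see in isolation. A minimal, self-contained sketch of the same pattern with a Manager-backed flag dict and lock (the namespace name and worker count are illustrative only):

from multiprocessing import Manager, Process
import os

def claim(namespace, init_flags, lock):
    # Mirrors try_initialize_namespace(): the first caller sets the flag and
    # gets to initialize; everyone else sees the flag and backs off.
    with lock:
        if namespace not in init_flags:
            init_flags[namespace] = True
            print(f"pid {os.getpid()}: initializing {namespace!r}")
            return
    print(f"pid {os.getpid()}: {namespace!r} already initialized, skipping")

if __name__ == "__main__":
    with Manager() as manager:
        flags, lock = manager.dict(), manager.Lock()
        workers = [Process(target=claim, args=("faiss_indices", flags, lock)) for _ in range(4)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
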
@@ -55,17 +70,13 @@ def try_initialize_namespace(namespace: str) -> bool: logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") - # ไฝฟ็”จๅ…จๅฑ€้”ไฟๆŠคๅ…ฑไบซๅญ—ๅ…ธ็š„่ฎฟ้—ฎ - with _get_global_lock(): - # ๆฃ€ๆŸฅๆ˜ฏๅฆๅทฒ็ปๅˆๅง‹ๅŒ– + with _global_lock: if namespace not in _init_flags: - # ่ฎพ็ฝฎๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— _init_flags[namespace] = True logger.info( f"Process {os.getpid()} ready to initialize namespace {namespace}" ) return True - logger.info( f"Process {os.getpid()} found namespace {namespace} already initialized" ) @@ -73,14 +84,6 @@ def try_initialize_namespace(namespace: str) -> bool: def _get_global_lock() -> LockType: - global _global_lock, is_multiprocess, _manager - - if _global_lock is None: - if is_multiprocess: - _global_lock = _manager.Lock() # Use manager for lock - else: - _global_lock = ThreadLock() - return _global_lock @@ -96,36 +99,20 @@ def get_scan_lock() -> LockType: def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" - global _share_objects, is_multiprocess, _manager - - if is_multiprocess and not _manager: - raise RuntimeError( - "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." - ) if namespace not in _share_objects: lock = _get_global_lock() with lock: if namespace not in _share_objects: - if is_multiprocess: + if _is_multiprocess.value: _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None return _share_objects[namespace] - -# ็งป้™คไธๅ†ไฝฟ็”จ็š„ๅ‡ฝๆ•ฐ - - def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" - global _shared_dicts, is_multiprocess, _manager - - if is_multiprocess and not _manager: - raise RuntimeError( - "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." 
- ) if namespace not in _shared_dicts: lock = _get_global_lock() diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 46638243..08ca202f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -267,6 +267,9 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): + from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data() + os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") From 7aec78833cb219bd95b4fe46074674955e8feb7e Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 13:25:22 +0800 Subject: [PATCH 19/77] Implement Gunicorn+Uvicorn integration for shared data preloading - Create run_with_gunicorn.py script to properly initialize shared data in the main process before forking worker processes - Revert unvicorn to single process mode only, and let gunicorn do all the multi-process jobs --- gunicorn_config.py | 80 +++++++++++++++ lightrag/api/lightrag_server.py | 23 ++++- lightrag/kg/shared_storage.py | 113 +++++++++++++++++---- run_with_gunicorn.py | 172 ++++++++++++++++++++++++++++++++ 4 files changed, 365 insertions(+), 23 deletions(-) create mode 100644 gunicorn_config.py create mode 100755 run_with_gunicorn.py diff --git a/gunicorn_config.py b/gunicorn_config.py new file mode 100644 index 00000000..8c1b22bf --- /dev/null +++ b/gunicorn_config.py @@ -0,0 +1,80 @@ +# gunicorn_config.py +import os +import multiprocessing +from lightrag.kg.shared_storage import finalize_share_data +from lightrag.api.utils_api import parse_args + +# Parse command line arguments +args = parse_args() + +# Determine worker count - from environment variable or command line arguments +workers = int(os.getenv('WORKERS', args.workers)) + +# If not specified, use CPU count * 2 + 1 (Gunicorn recommended configuration) +if workers <= 1: + workers = multiprocessing.cpu_count() * 2 + 1 + +# Binding address +bind = f"{os.getenv('HOST', args.host)}:{os.getenv('PORT', args.port)}" + +# Enable preload_app option +preload_app = True + +# Use Uvicorn worker +worker_class = "uvicorn.workers.UvicornWorker" + +# Other Gunicorn configurations +timeout = int(os.getenv('TIMEOUT', 120)) +keepalive = 5 + +# Optional SSL configuration +if args.ssl: + certfile = args.ssl_certfile + keyfile = args.ssl_keyfile + +# Logging configuration +errorlog = os.getenv('ERROR_LOG', '-') # '-' means stderr +accesslog = os.getenv('ACCESS_LOG', '-') # '-' means stderr +loglevel = os.getenv('LOG_LEVEL', 'info') + +def on_starting(server): + """ + Executed when Gunicorn starts, before forking the first worker processes + You can use this function to do more initialization tasks for all processes + """ + print("=" * 80) + print(f"GUNICORN MASTER PROCESS: on_starting jobs for all {workers} workers") + print(f"Process ID: {os.getpid()}") + print("=" * 80) + + # Memory usage monitoring + try: + import psutil + process = psutil.Process(os.getpid()) + memory_info = process.memory_info() + msg = f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB" + print(msg) + except ImportError: + print("psutil not installed, skipping memory usage reporting") + + print("=" * 80) + print("Gunicorn initialization complete, forking workers...") + print("=" * 80) + + +def on_exit(server): + """ + Executed when Gunicorn is shutting down. + This is a good place to release shared resources. 
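Editor's aside: on_starting and on_exit are only two of Gunicorn's server hooks; the per-worker ones follow the same convention of module-level functions with fixed names in the config file. A hedged sketch of two that could sit alongside them (standard Gunicorn hook signatures; not part of this patch):

def post_fork(server, worker):
    # runs in each worker process right after it is forked from the master
    print(f"Worker spawned (pid: {worker.pid})")

def worker_exit(server, worker):
    # runs as a worker shuts down; per-worker cleanup can go here
    print(f"Worker exiting (pid: {worker.pid})")
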
+ """ + print("=" * 80) + print("GUNICORN MASTER PROCESS: Shutting down") + print(f"Process ID: {os.getpid()}") + print("=" * 80) + + # Release shared resources + finalize_share_data() + + print("=" * 80) + print("Gunicorn shutdown complete") + print("=" * 80) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 3af8887d..a9c9ab04 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -483,17 +483,28 @@ def main(): display_splash_screen(args) + # Check if running under Gunicorn + if 'GUNICORN_CMD_ARGS' in os.environ: + # If started with Gunicorn, return directly as Gunicorn will call get_application + print("Running under Gunicorn - worker management handled by Gunicorn") + return + + # If not running under Gunicorn, initialize shared data here from lightrag.kg.shared_storage import initialize_share_data - initialize_share_data(args.workers) - + print("Starting in single-process mode") + initialize_share_data(1) # Force single process mode + + # Create application instance directly instead of using factory function + app = create_app(args) + + # Start Uvicorn in single process mode uvicorn_config = { - "app": "lightrag.api.lightrag_server:get_application", - "factory": True, + "app": app, # Pass application instance directly instead of string path "host": args.host, "port": args.port, - "workers": args.workers, "log_config": None, # Disable default config } + if args.ssl: uvicorn_config.update( { @@ -501,6 +512,8 @@ def main(): "ssl_keyfile": args.ssl_keyfile, } ) + + print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}") uvicorn.run(**uvicorn_config) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 6b5c07f6..8dc9e1a9 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,10 +1,19 @@ import os +import sys from multiprocessing.synchronize import Lock as ProcessLock from threading import Lock as ThreadLock from multiprocessing import Manager from typing import Any, Dict, Optional, Union from lightrag.utils import logger +# Define a direct print function for critical logs that must be visible in all processes +def direct_log(message, level="INFO"): + """ + Log a message directly to stderr to ensure visibility in all processes, + including the Gunicorn master process. + """ + print(f"{level}: {message}", file=sys.stderr, flush=True) + LockType = Union[ProcessLock, ThreadLock] _manager = None @@ -21,41 +30,60 @@ _global_lock: Optional[LockType] = None def initialize_share_data(workers: int = 1): - """Initialize storage data""" - global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + """ + Initialize shared storage data for single or multi-process mode. + When used with Gunicorn's preload feature, this function is called once in the + master process before forking worker processes, allowing all workers to share + the same initialized data. + + In single-process mode, this function is called during LightRAG object initialization. + + The function determines whether to use cross-process shared variables for data storage + based on the number of workers. If workers=1, it uses thread locks and local dictionaries. + If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager. + + Args: + workers (int): Number of worker processes. If 1, single-process mode is used. + If > 1, multi-process mode with shared memory is used. 
+ """ + global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + + # Check if already initialized if _initialized and _initialized.value: is_multiprocess = _is_multiprocess.value - if _is_multiprocess.value: - logger.info(f"Process {os.getpid()} storage data already initialized!") - return - + direct_log(f"Process {os.getpid()} storage data already initialized (multiprocess={_is_multiprocess.value})!") + return + _manager = Manager() _initialized = _manager.Value("b", False) _is_multiprocess = _manager.Value("b", False) - if workers == 1: - _is_multiprocess.value = False - _global_lock = ThreadLock() - _shared_dicts = {} - _share_objects = {} - _init_flags = {} - logger.info(f"Process {os.getpid()} storage data created for Single Process") - else: + # Force multi-process mode if workers > 1 + if workers > 1: _is_multiprocess.value = True _global_lock = _manager.Lock() # Create shared dictionaries with manager _shared_dicts = _manager.dict() _share_objects = _manager.dict() - _init_flags = _manager.dict() # ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธๅญ˜ๅ‚จๅˆๅง‹ๅŒ–ๆ ‡ๅฟ— - logger.info(f"Process {os.getpid()} storage data created for Multiple Process") + _init_flags = _manager.dict() # Use shared dictionary to store initialization flags + direct_log(f"Process {os.getpid()} storage data created for Multiple Process (workers={workers})") + else: + _is_multiprocess.value = False + _global_lock = ThreadLock() + _shared_dicts = {} + _share_objects = {} + _init_flags = {} + direct_log(f"Process {os.getpid()} storage data created for Single Process") + # Mark as initialized + _initialized.value = True is_multiprocess = _is_multiprocess.value def try_initialize_namespace(namespace: str) -> bool: """ - ๅฐ่ฏ•ๅˆๅง‹ๅŒ–ๅ‘ฝๅ็ฉบ้—ดใ€‚่ฟ”ๅ›žTrue่กจ็คบๅฝ“ๅ‰่ฟ›็จ‹่Žทๅพ—ไบ†ๅˆๅง‹ๅŒ–ๆƒ้™ใ€‚ - ไฝฟ็”จๅ…ฑไบซๅญ—ๅ…ธ็š„ๅŽŸๅญๆ“ไฝœ็กฎไฟๅชๆœ‰ไธ€ไธช่ฟ›็จ‹่ƒฝๆˆๅŠŸๅˆๅง‹ๅŒ–ใ€‚ + Try to initialize a namespace. Returns True if the current process gets initialization permission. + Uses atomic operations on shared dictionaries to ensure only one process can successfully initialize. """ global _init_flags, _manager @@ -126,3 +154,52 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: def get_scan_progress() -> Dict[str, Any]: """get storage space for document scanning progress data""" return get_namespace_data("scan_progress") + + +def finalize_share_data(): + """ + Release shared resources and clean up. + + This function should be called when the application is shutting down + to properly release shared resources and avoid memory leaks. + + In multi-process mode, it shuts down the Manager and releases all shared objects. + In single-process mode, it simply resets the global variables. 
+ """ + global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + + # Check if already initialized + if not (_initialized and _initialized.value): + direct_log(f"Process {os.getpid()} storage data not initialized, nothing to finalize") + return + + direct_log(f"Process {os.getpid()} finalizing storage data (multiprocess={_is_multiprocess.value})") + + # In multi-process mode, shut down the Manager + if _is_multiprocess.value and _manager is not None: + try: + # Clear shared dictionaries first + if _shared_dicts is not None: + _shared_dicts.clear() + if _share_objects is not None: + _share_objects.clear() + if _init_flags is not None: + _init_flags.clear() + + # Shut down the Manager + _manager.shutdown() + direct_log(f"Process {os.getpid()} Manager shutdown complete") + except Exception as e: + direct_log(f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR") + + # Reset global variables + _manager = None + _initialized = None + _is_multiprocess = None + is_multiprocess = None + _shared_dicts = None + _share_objects = None + _init_flags = None + _global_lock = None + + direct_log(f"Process {os.getpid()} storage data finalization complete") diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py new file mode 100755 index 00000000..44a49e93 --- /dev/null +++ b/run_with_gunicorn.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +""" +Start LightRAG server with Gunicorn +""" +import os +import sys +import json +import signal +import argparse +import subprocess +from lightrag.api.utils_api import parse_args, display_splash_screen +from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data + +# Signal handler for graceful shutdown +def signal_handler(sig, frame): + print("\n\n" + "="*80) + print("RECEIVED TERMINATION SIGNAL") + print(f"Process ID: {os.getpid()}") + print("="*80 + "\n") + + # Release shared resources + finalize_share_data() + + # Exit with success status + sys.exit(0) + +def main(): + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # kill command + # Create a parser to handle Gunicorn-specific parameters + parser = argparse.ArgumentParser( + description="Start LightRAG server with Gunicorn" + ) + parser.add_argument( + "--workers", + type=int, + help="Number of worker processes (overrides the default or config.ini setting)" + ) + parser.add_argument( + "--timeout", + type=int, + help="Worker timeout in seconds (default: 120)" + ) + parser.add_argument( + "--log-level", + choices=["debug", "info", "warning", "error", "critical"], + help="Gunicorn log level" + ) + + # Parse Gunicorn-specific arguments + gunicorn_args, remaining_args = parser.parse_known_args() + + # Pass remaining arguments to LightRAG's parse_args + sys.argv = [sys.argv[0]] + remaining_args + args = parse_args() + + # If workers specified, override args value + if gunicorn_args.workers: + args.workers = gunicorn_args.workers + os.environ["WORKERS"] = str(gunicorn_args.workers) + + # If timeout specified, set environment variable + if gunicorn_args.timeout: + os.environ["TIMEOUT"] = str(gunicorn_args.timeout) + + # If log-level specified, set environment variable + if gunicorn_args.log_level: + os.environ["LOG_LEVEL"] = gunicorn_args.log_level + + # Save all LightRAG args to environment variable for worker processes + # This is the key step for passing arguments to lightrag_server.py + 
os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) + + # Display startup information + display_splash_screen(args) + + print("๐Ÿš€ Starting LightRAG with Gunicorn") + print(f"๐Ÿ”„ Worker management: Gunicorn (workers={args.workers})") + print("๐Ÿ” Preloading app: Enabled") + print("๐Ÿ“ Note: Using Gunicorn's preload feature for shared data initialization") + print("\n\n" + "="*80) + print("MAIN PROCESS INITIALIZATION") + print(f"Process ID: {os.getpid()}") + print(f"Workers setting: {args.workers}") + print("="*80 + "\n") + + # Start application with Gunicorn using direct Python API + # Ensure WORKERS environment variable is set before importing gunicorn_config + if args.workers > 1: + os.environ["WORKERS"] = str(args.workers) + + # Import Gunicorn's StandaloneApplication + from gunicorn.app.base import BaseApplication + + # Define a custom application class that loads our config + class GunicornApp(BaseApplication): + def __init__(self, app, options=None): + self.options = options or {} + self.application = app + super().__init__() + + def load_config(self): + # Define valid Gunicorn configuration options + valid_options = { + 'bind', 'workers', 'worker_class', 'timeout', 'keepalive', + 'preload_app', 'errorlog', 'accesslog', 'loglevel', + 'certfile', 'keyfile', 'limit_request_line', 'limit_request_fields', + 'limit_request_field_size', 'graceful_timeout', 'max_requests', + 'max_requests_jitter' + } + + # Special hooks that need to be set separately + special_hooks = { + 'on_starting', 'on_reload', 'on_exit', 'pre_fork', 'post_fork', + 'pre_exec', 'pre_request', 'post_request', 'worker_init', + 'worker_exit', 'nworkers_changed', 'child_exit' + } + + # Import the gunicorn_config module directly + import importlib.util + spec = importlib.util.spec_from_file_location("gunicorn_config", "gunicorn_config.py") + self.config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(self.config_module) + + # Set configuration options + for key in dir(self.config_module): + if key in valid_options: + value = getattr(self.config_module, key) + # Skip functions like on_starting + if not callable(value): + self.cfg.set(key, value) + # Set special hooks + elif key in special_hooks: + value = getattr(self.config_module, key) + if callable(value): + self.cfg.set(key, value) + + # Override with command line arguments if provided + if gunicorn_args.workers: + self.cfg.set("workers", gunicorn_args.workers) + if gunicorn_args.timeout: + self.cfg.set("timeout", gunicorn_args.timeout) + if gunicorn_args.log_level: + self.cfg.set("loglevel", gunicorn_args.log_level) + + def load(self): + # Import the application + from lightrag.api.lightrag_server import get_application + return get_application() + + # Create the application + app = GunicornApp("") + + # Directly call initialize_share_data with the correct workers value + from lightrag.kg.shared_storage import initialize_share_data + + # Force workers to be an integer and greater than 1 for multi-process mode + workers_count = int(args.workers) + if workers_count > 1: + # Set a flag to indicate we're in the main process + os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" + initialize_share_data(workers_count) + else: + initialize_share_data(1) + + # Run the application + print("\nStarting Gunicorn with direct Python API...") + app.run() + +if __name__ == "__main__": + main() From 03d05b094d04a445732281ea9f7ff952fdd9ad1d Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 14:13:42 +0800 Subject: [PATCH 20/77] Improve Gunicorn support and 
cleanup shared storage initialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Move Gunicorn check before other startup โ€ข Improve startup flow organization --- lightrag/api/lightrag_server.py | 20 +++++++------------- lightrag/kg/shared_storage.py | 2 +- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index a9c9ab04..9f162290 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -470,8 +470,13 @@ def configure_logging(): def main(): - from multiprocessing import freeze_support + # Check if running under Gunicorn + if 'GUNICORN_CMD_ARGS' in os.environ: + # If started with Gunicorn, return directly as Gunicorn will call get_application + print("Running under Gunicorn - worker management handled by Gunicorn") + return + from multiprocessing import freeze_support freeze_support() args = parse_args() @@ -482,18 +487,7 @@ def main(): configure_logging() display_splash_screen(args) - - # Check if running under Gunicorn - if 'GUNICORN_CMD_ARGS' in os.environ: - # If started with Gunicorn, return directly as Gunicorn will call get_application - print("Running under Gunicorn - worker management handled by Gunicorn") - return - - # If not running under Gunicorn, initialize shared data here - from lightrag.kg.shared_storage import initialize_share_data - print("Starting in single-process mode") - initialize_share_data(1) # Force single process mode - + # Create application instance directly instead of using factory function app = create_app(args) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 8dc9e1a9..8956d995 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -52,7 +52,7 @@ def initialize_share_data(workers: int = 1): # Check if already initialized if _initialized and _initialized.value: is_multiprocess = _is_multiprocess.value - direct_log(f"Process {os.getpid()} storage data already initialized (multiprocess={_is_multiprocess.value})!") + direct_log(f"Process {os.getpid()} storage data already initialized (multiprocess={_is_multiprocess.value})") return _manager = Manager() From f007ebf006815a1854595aef665208c928626bc0 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 14:55:07 +0800 Subject: [PATCH 21/77] Refactor initialization logic for vector, KV and graph storage implementations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add try_initialize_namespace check โ€ข Move init code out of storage locks โ€ข Reduce redundant init conditions โ€ข Simplify initialization flow โ€ข Make init thread-safer --- lightrag/kg/faiss_impl.py | 31 ++++++++++---------- lightrag/kg/json_kv_impl.py | 14 +++++---- lightrag/kg/nano_vector_db_impl.py | 36 +++++++++++------------ lightrag/kg/networkx_impl.py | 46 ++++++++++++++++-------------- 4 files changed, 67 insertions(+), 60 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index a9d058f4..0315de7c 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -15,6 +15,7 @@ from .shared_storage import ( get_storage_lock, get_namespace_object, is_multiprocess, + try_initialize_namespace, ) if not pm.is_installed("faiss"): @@ -52,26 +53,26 @@ class FaissVectorDBStorage(BaseVectorStorage): self._dim = self.embedding_func.embedding_dim self._storage_lock = get_storage_lock() + # check need_init must before 
get_namespace_object/get_namespace_data + need_init = try_initialize_namespace("faiss_indices") self._index = get_namespace_object("faiss_indices") self._id_to_meta = get_namespace_data("faiss_meta") - with self._storage_lock: + if need_init: if is_multiprocess: - if self._index.value is None: - # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). - # If you have a large number of vectors, you might want IVF or other indexes. - # For demonstration, we use a simple IndexFlatIP. - self._index.value = faiss.IndexFlatIP(self._dim) - # Keep a local store for metadata, IDs, etc. - # Maps โ†’ metadata (including your original ID). - self._id_to_meta.update({}) - # Attempt to load an existing index + metadata from disk - self._load_faiss_index() + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). + # If you have a large number of vectors, you might want IVF or other indexes. + # For demonstration, we use a simple IndexFlatIP. + self._index.value = faiss.IndexFlatIP(self._dim) + # Keep a local store for metadata, IDs, etc. + # Maps โ†’ metadata (including your original ID). + self._id_to_meta.update({}) + # Attempt to load an existing index + metadata from disk + self._load_faiss_index() else: - if self._index is None: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta.update({}) - self._load_faiss_index() + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta.update({}) + self._load_faiss_index() async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 4c80854a..f13cdfb6 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -10,7 +10,7 @@ from lightrag.utils import ( logger, write_json, ) -from .shared_storage import get_namespace_data, get_storage_lock +from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace @final @@ -20,11 +20,15 @@ class JsonKVStorage(BaseKVStorage): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + + # check need_init must before get_namespace_data + need_init = try_initialize_namespace(self.namespace) self._data = get_namespace_data(self.namespace) - with self._storage_lock: - if not self._data: - self._data: dict[str, Any] = load_json(self._file_name) or {} - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + if need_init: + loaded_data = load_json(self._file_name) or {} + with self._storage_lock: + self._data.update(loaded_data) + logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data") async def index_done_callback(self) -> None: # ๆ–‡ไปถๅ†™ๅ…ฅ้œ€่ฆๅŠ ้”๏ผŒ้˜ฒๆญขๅคšไธช่ฟ›็จ‹ๅŒๆ—ถๅ†™ๅ…ฅๅฏผ่‡ดๆ–‡ไปถๆŸๅ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 7707a0f0..64b0e720 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -11,7 +11,7 @@ from lightrag.utils import ( ) import pipmaster as pm from lightrag.base import BaseVectorStorage -from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess +from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") @@ -40,27 +40,27 @@ class NanoVectorDBStorage(BaseVectorStorage): ) 
self._max_batch_size = self.global_config["embedding_batch_num"] + # check need_init must before get_namespace_object + need_init = try_initialize_namespace(self.namespace) self._client = get_namespace_object(self.namespace) - with self._storage_lock: + if need_init: if is_multiprocess: - if self._client.value is None: - self._client.value = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name, - ) - logger.info( - f"Initialized vector DB client for namespace {self.namespace}" - ) + self._client.value = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + logger.info( + f"Initialized vector DB client for namespace {self.namespace}" + ) else: - if self._client is None: - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name, - ) - logger.info( - f"Initialized vector DB client for namespace {self.namespace}" - ) + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + logger.info( + f"Initialized vector DB client for namespace {self.namespace}" + ) def _get_client(self): """Get the appropriate client instance based on multiprocess mode""" diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 07bd9666..aec49e6c 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -6,7 +6,7 @@ import numpy as np from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from lightrag.base import BaseGraphStorage -from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess +from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace import pipmaster as pm @@ -74,32 +74,34 @@ class NetworkXStorage(BaseGraphStorage): self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) self._storage_lock = get_storage_lock() + + # check need_init must before get_namespace_object + need_init = try_initialize_namespace(self.namespace) self._graph = get_namespace_object(self.namespace) - with self._storage_lock: + + if need_init: if is_multiprocess: - if self._graph.value is None: - preloaded_graph = NetworkXStorage.load_nx_graph( - self._graphml_xml_file + preloaded_graph = NetworkXStorage.load_nx_graph( + self._graphml_xml_file + ) + self._graph.value = preloaded_graph or nx.Graph() + if preloaded_graph: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) - self._graph.value = preloaded_graph or nx.Graph() - if preloaded_graph: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - else: - logger.info("Created new empty graph") + else: + logger.info("Created new empty graph") else: - if self._graph is None: - preloaded_graph = NetworkXStorage.load_nx_graph( - self._graphml_xml_file + preloaded_graph = NetworkXStorage.load_nx_graph( + self._graphml_xml_file + ) + self._graph = preloaded_graph or nx.Graph() + if preloaded_graph: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) - self._graph = preloaded_graph or nx.Graph() - if preloaded_graph: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} 
nodes, {preloaded_graph.number_of_edges()} edges" - ) - else: - logger.info("Created new empty graph") + else: + logger.info("Created new empty graph") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, From 438e4780a8b1f640bde6cd23b3563a1ae1330e1f Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 15:09:19 +0800 Subject: [PATCH 22/77] Refactor Faiss index access with helper method to improve code organization --- lightrag/kg/faiss_impl.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 0315de7c..b6b998e4 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -74,6 +74,13 @@ class FaissVectorDBStorage(BaseVectorStorage): self._id_to_meta.update({}) self._load_faiss_index() + def _get_index(self): + """ + Helper method to get the correct index object based on multiprocess mode. + Returns the actual index object that can be used for operations. + """ + return self._index.value if is_multiprocess else self._index + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -142,11 +149,9 @@ class FaissVectorDBStorage(BaseVectorStorage): self._remove_faiss_ids(existing_ids_to_remove) # Step 2: Add new vectors - start_idx = (self._index.value if is_multiprocess else self._index).ntotal - if is_multiprocess: - self._index.value.add(embeddings) - else: - self._index.add(embeddings) + index = self._get_index() + start_idx = index.ntotal + index.add(embeddings) # Step 3: Store metadata + vector for each new ID for i, meta in enumerate(list_data): @@ -173,9 +178,7 @@ class FaissVectorDBStorage(BaseVectorStorage): # Perform the similarity search with self._storage_lock: - distances, indices = ( - self._index.value if is_multiprocess else self._index - ).search(embedding, top_k) + distances, indices = self._get_index().search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -303,7 +306,7 @@ class FaissVectorDBStorage(BaseVectorStorage): """ with self._storage_lock: faiss.write_index( - self._index.value if is_multiprocess else self._index, + self._get_index(), self._faiss_index_file, ) From 1699b10a255c8ab8e72f3738ff2852158baa8bb9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 15:14:54 +0800 Subject: [PATCH 23/77] Refactor direct client/graph access to reduce redundant get calls in vector/graph ops --- lightrag/kg/nano_vector_db_impl.py | 25 +++++++------------ lightrag/kg/networkx_impl.py | 40 +++++++++++------------------- 2 files changed, 23 insertions(+), 42 deletions(-) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 64b0e720..953a19a7 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -97,8 +97,7 @@ class NanoVectorDBStorage(BaseVectorStorage): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] with self._storage_lock: - client = self._get_client() - results = client.upsert(datas=list_data) + results = self._get_client().upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. 
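Editor's aside: the wrapped client calls above map directly onto nano-vectordb's API. A toy end-to-end of the same upsert/query/save trio, assuming the usual nano_vectordb import (dimensions, ids, and the storage file name are made up):

import numpy as np
from nano_vectordb import NanoVectorDB

db = NanoVectorDB(4, storage_file="vdb_demo.json")
db.upsert(datas=[
    {"__id__": "ent-1", "__vector__": np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)},
    {"__id__": "ent-2", "__vector__": np.array([0.4, 0.3, 0.2, 0.1], dtype=np.float32)},
])
hits = db.query(query=np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32), top_k=1)
print(hits)   # nearest record, including its "__id__"
db.save()     # persist to disk, as index_done_callback() does
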
@@ -112,8 +111,7 @@ class NanoVectorDBStorage(BaseVectorStorage): embedding = embedding[0] with self._storage_lock: - client = self._get_client() - results = client.query( + results = self._get_client().query( query=embedding, top_k=top_k, better_than_threshold=self.cosine_better_than_threshold, @@ -131,8 +129,7 @@ class NanoVectorDBStorage(BaseVectorStorage): @property def client_storage(self): - client = self._get_client() - return getattr(client, "_NanoVectorDB__storage") + return getattr(self._get_client(), "_NanoVectorDB__storage") async def delete(self, ids: list[str]): """Delete vectors with specified IDs @@ -142,8 +139,7 @@ class NanoVectorDBStorage(BaseVectorStorage): """ try: with self._storage_lock: - client = self._get_client() - client.delete(ids) + self._get_client().delete(ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) @@ -158,10 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage): ) with self._storage_lock: - client = self._get_client() # Check if the entity exists - if client.get([entity_id]): - client.delete([entity_id]) + if self._get_client().get([entity_id]): + self._get_client().delete([entity_id]) logger.debug(f"Successfully deleted entity {entity_name}") else: logger.debug(f"Entity {entity_name} not found in storage") @@ -171,8 +166,7 @@ class NanoVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: try: with self._storage_lock: - client = self._get_client() - storage = getattr(client, "_NanoVectorDB__storage") + storage = getattr(self._get_client(), "_NanoVectorDB__storage") relations = [ dp for dp in storage["data"] @@ -184,7 +178,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: - client.delete(ids_to_delete) + self._get_client().delete(ids_to_delete) logger.debug( f"Deleted {len(ids_to_delete)} relations for {entity_name}" ) @@ -195,5 +189,4 @@ class NanoVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: with self._storage_lock: - client = self._get_client() - client.save() + self._get_client().save() diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index aec49e6c..db059393 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -115,65 +115,54 @@ class NetworkXStorage(BaseGraphStorage): async def index_done_callback(self) -> None: with self._storage_lock: - graph = self._get_graph() - NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file) + NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: with self._storage_lock: - graph = self._get_graph() - return graph.has_node(node_id) + return self._get_graph().has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: with self._storage_lock: - graph = self._get_graph() - return graph.has_edge(source_node_id, target_node_id) + return self._get_graph().has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: with self._storage_lock: - graph = self._get_graph() - return graph.nodes.get(node_id) + return self._get_graph().nodes.get(node_id) async def node_degree(self, node_id: str) -> int: with self._storage_lock: - graph = self._get_graph() - return graph.degree(node_id) + return self._get_graph().degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: with self._storage_lock: - graph = 
self._get_graph() - return graph.degree(src_id) + graph.degree(tgt_id) + return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: with self._storage_lock: - graph = self._get_graph() - return graph.edges.get((source_node_id, target_node_id)) + return self._get_graph().edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: with self._storage_lock: - graph = self._get_graph() - if graph.has_node(source_node_id): - return list(graph.edges(source_node_id)) + if self._get_graph().has_node(source_node_id): + return list(self._get_graph().edges(source_node_id)) return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: with self._storage_lock: - graph = self._get_graph() - graph.add_node(node_id, **node_data) + self._get_graph().add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: with self._storage_lock: - graph = self._get_graph() - graph.add_edge(source_node_id, target_node_id, **edge_data) + self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: with self._storage_lock: - graph = self._get_graph() - if graph.has_node(node_id): - graph.remove_node(node_id) + if self._get_graph().has_node(node_id): + self._get_graph().remove_node(node_id) logger.debug(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") @@ -227,9 +216,8 @@ class NetworkXStorage(BaseGraphStorage): [label1, label2, ...] # Alphabetically sorted label list """ with self._storage_lock: - graph = self._get_graph() labels = set() - for node in graph.nodes(): + for node in self._get_graph().nodes(): labels.add(str(node)) # Add node id as a label # Return sorted list From e881bc070947957f92993f6a9ab13ff5c291127b Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 15:36:12 +0800 Subject: [PATCH 24/77] simplify process state management by removing redundant multiprocess flag --- lightrag/kg/shared_storage.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 8956d995..b4bd5613 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -18,7 +18,6 @@ LockType = Union[ProcessLock, ThreadLock] _manager = None _initialized = None -_is_multiprocess = None is_multiprocess = None # shared data for storage across processes @@ -47,21 +46,18 @@ def initialize_share_data(workers: int = 1): workers (int): Number of worker processes. If 1, single-process mode is used. If > 1, multi-process mode with shared memory is used. 
""" - global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + global _manager, is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized # Check if already initialized - if _initialized and _initialized.value: - is_multiprocess = _is_multiprocess.value - direct_log(f"Process {os.getpid()} storage data already initialized (multiprocess={_is_multiprocess.value})") + if _initialized: + direct_log(f"Process {os.getpid()} storage data already initialized (multiprocess={is_multiprocess})") return _manager = Manager() - _initialized = _manager.Value("b", False) - _is_multiprocess = _manager.Value("b", False) # Force multi-process mode if workers > 1 if workers > 1: - _is_multiprocess.value = True + is_multiprocess = True _global_lock = _manager.Lock() # Create shared dictionaries with manager _shared_dicts = _manager.dict() @@ -69,7 +65,7 @@ def initialize_share_data(workers: int = 1): _init_flags = _manager.dict() # Use shared dictionary to store initialization flags direct_log(f"Process {os.getpid()} storage data created for Multiple Process (workers={workers})") else: - _is_multiprocess.value = False + is_multiprocess = False _global_lock = ThreadLock() _shared_dicts = {} _share_objects = {} @@ -77,8 +73,7 @@ def initialize_share_data(workers: int = 1): direct_log(f"Process {os.getpid()} storage data created for Single Process") # Mark as initialized - _initialized.value = True - is_multiprocess = _is_multiprocess.value + _initialized = True def try_initialize_namespace(namespace: str) -> bool: """ @@ -87,7 +82,7 @@ def try_initialize_namespace(namespace: str) -> bool: """ global _init_flags, _manager - if _is_multiprocess.value: + if is_multiprocess: if _init_flags is None: raise RuntimeError( "Shared storage not initialized. Call initialize_share_data() first." @@ -132,7 +127,7 @@ def get_namespace_object(namespace: str) -> Any: lock = _get_global_lock() with lock: if namespace not in _share_objects: - if _is_multiprocess.value: + if is_multiprocess: _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None @@ -166,17 +161,17 @@ def finalize_share_data(): In multi-process mode, it shuts down the Manager and releases all shared objects. In single-process mode, it simply resets the global variables. 
""" - global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + global _manager, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized # Check if already initialized - if not (_initialized and _initialized.value): + if not _initialized: direct_log(f"Process {os.getpid()} storage data not initialized, nothing to finalize") return - direct_log(f"Process {os.getpid()} finalizing storage data (multiprocess={_is_multiprocess.value})") + direct_log(f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})") # In multi-process mode, shut down the Manager - if _is_multiprocess.value and _manager is not None: + if is_multiprocess and _manager is not None: try: # Clear shared dictionaries first if _shared_dicts is not None: @@ -195,7 +190,6 @@ def finalize_share_data(): # Reset global variables _manager = None _initialized = None - _is_multiprocess = None is_multiprocess = None _shared_dicts = None _share_objects = None From 92ecb0da970432661fe708fbd6793afe3a83abb3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 16:07:00 +0800 Subject: [PATCH 25/77] Refactor document scanning progress share variable initialization --- lightrag/api/routers/document_routes.py | 19 +++---------------- lightrag/lightrag.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 1f591750..5d90fb83 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -20,7 +20,7 @@ from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus from ..utils_api import get_api_key_dependency from lightrag.kg.shared_storage import ( - get_scan_progress, + get_namespace_data, get_scan_lock, ) @@ -376,21 +376,8 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" - scan_progress = get_scan_progress() + scan_progress = get_namespace_data("scan_progress") scan_lock = get_scan_lock() - - # Initialize scan_progress if not already initialized - if not scan_progress: - scan_progress.update( - { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - } - ) - with scan_lock: if scan_progress.get("is_scanning", False): ASCIIColors.info("Skip document scanning(another scanning is active)") @@ -491,7 +478,7 @@ def create_document_routes( - total_files: Total number of files to process - progress: Percentage of completion """ - return dict(get_scan_progress()) + return dict(get_namespace_data("scan_progress")) @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 08ca202f..ec0accc3 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -267,8 +267,20 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - from lightrag.kg.shared_storage import initialize_share_data + from lightrag.kg.shared_storage import initialize_share_data, try_initialize_namespace, get_namespace_data initialize_share_data() + need_init = try_initialize_namespace("scan_progress") + scan_progress = get_namespace_data("scan_progress") + if need_init: + scan_progress.update( + { + 
"is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) From 946095ef8023d2ecf8b3ca34b941fb9bd009fc6c Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 19:03:53 +0800 Subject: [PATCH 26/77] Fix multiprocess dict creation logic, add process safety locks for namespace creation. --- gunicorn_config.py | 1 - lightrag/kg/shared_storage.py | 63 +++++++++++++++++------------------ lightrag/lightrag.py | 34 ++++++++++--------- 3 files changed, 50 insertions(+), 48 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 8c1b22bf..f4c9178e 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -61,7 +61,6 @@ def on_starting(server): print("Gunicorn initialization complete, forking workers...") print("=" * 80) - def on_exit(server): """ Executed when Gunicorn is shutting down. diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index b4bd5613..73ffb306 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -50,7 +50,7 @@ def initialize_share_data(workers: int = 1): # Check if already initialized if _initialized: - direct_log(f"Process {os.getpid()} storage data already initialized (multiprocess={is_multiprocess})") + direct_log(f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})") return _manager = Manager() @@ -63,14 +63,14 @@ def initialize_share_data(workers: int = 1): _shared_dicts = _manager.dict() _share_objects = _manager.dict() _init_flags = _manager.dict() # Use shared dictionary to store initialization flags - direct_log(f"Process {os.getpid()} storage data created for Multiple Process (workers={workers})") + direct_log(f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})") else: is_multiprocess = False _global_lock = ThreadLock() _shared_dicts = {} _share_objects = {} _init_flags = {} - direct_log(f"Process {os.getpid()} storage data created for Single Process") + direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") # Mark as initialized _initialized = True @@ -82,28 +82,16 @@ def try_initialize_namespace(namespace: str) -> bool: """ global _init_flags, _manager - if is_multiprocess: - if _init_flags is None: - raise RuntimeError( - "Shared storage not initialized. Call initialize_share_data() first." 
- ) - else: - if _init_flags is None: - _init_flags = {} + if _init_flags is None: + direct_log(f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}", level="ERROR") + raise ValueError("Shared dictionaries not initialized") - logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") - - with _global_lock: - if namespace not in _init_flags: - _init_flags[namespace] = True - logger.info( - f"Process {os.getpid()} ready to initialize namespace {namespace}" - ) - return True - logger.info( - f"Process {os.getpid()} found namespace {namespace} already initialized" - ) - return False + if namespace not in _init_flags: + _init_flags[namespace] = True + direct_log(f"Process {os.getpid()} ready to initialize namespace {namespace}") + return True + direct_log(f"Process {os.getpid()} namespace {namespace} already to initialized") + return False def _get_global_lock() -> LockType: @@ -123,26 +111,37 @@ def get_scan_lock() -> LockType: def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" - if namespace not in _share_objects: - lock = _get_global_lock() - with lock: + if _share_objects is None: + direct_log(f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", level="ERROR") + raise ValueError("Shared dictionaries not initialized") + + lock = _get_global_lock() + with lock: + if namespace not in _share_objects: if namespace not in _share_objects: if is_multiprocess: _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None + direct_log(f"Created namespace({namespace}): type={type(_share_objects[namespace])}, pid={os.getpid()}") return _share_objects[namespace] def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" + if _shared_dicts is None: + direct_log(f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", level="ERROR") + raise ValueError("Shared dictionaries not initialized") - if namespace not in _shared_dicts: - lock = _get_global_lock() - with lock: - if namespace not in _shared_dicts: + lock = _get_global_lock() + with lock: + if namespace not in _shared_dicts: + if is_multiprocess and _manager is not None: + _shared_dicts[namespace] = _manager.dict() + else: _shared_dicts[namespace] = {} - + direct_log(f"Created namespace({namespace}): type={type(_shared_dicts[namespace])}, pid={os.getpid()}") + return _shared_dicts[namespace] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ec0accc3..924fbae3 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -267,25 +267,29 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - from lightrag.kg.shared_storage import initialize_share_data, try_initialize_namespace, get_namespace_data - initialize_share_data() - need_init = try_initialize_namespace("scan_progress") - scan_progress = get_namespace_data("scan_progress") - if need_init: - scan_progress.update( - { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - } - ) - os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") + from lightrag.kg.shared_storage import initialize_share_data, try_initialize_namespace, get_namespace_data + initialize_share_data() + + need_init = 
try_initialize_namespace("scan_progress") + scan_progress = get_namespace_data("scan_progress") + logger.info(f"scan_progress type after init: {type(scan_progress)}") + scan_progress.update( + { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) + scan_progress = get_namespace_data("scan_progress") + logger.info(f"scan_progress type after update: {type(scan_progress)}") + logger.info(f"Scan_progres value after update: {scan_progress}") + if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) From 64f22966a3ef5104e4ef288ec2ba0c65f229350c Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 19:05:51 +0800 Subject: [PATCH 27/77] Fix linting --- gunicorn_config.py | 25 +++--- lightrag/api/lightrag_server.py | 11 +-- lightrag/kg/json_kv_impl.py | 8 +- lightrag/kg/nano_vector_db_impl.py | 7 +- lightrag/kg/networkx_impl.py | 19 ++--- lightrag/kg/shared_storage.py | 107 ++++++++++++++++++-------- lightrag/lightrag.py | 14 ++-- run_with_gunicorn.py | 117 +++++++++++++++++------------ 8 files changed, 196 insertions(+), 112 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index f4c9178e..7239acd9 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -8,7 +8,7 @@ from lightrag.api.utils_api import parse_args args = parse_args() # Determine worker count - from environment variable or command line arguments -workers = int(os.getenv('WORKERS', args.workers)) +workers = int(os.getenv("WORKERS", args.workers)) # If not specified, use CPU count * 2 + 1 (Gunicorn recommended configuration) if workers <= 1: @@ -24,7 +24,7 @@ preload_app = True worker_class = "uvicorn.workers.UvicornWorker" # Other Gunicorn configurations -timeout = int(os.getenv('TIMEOUT', 120)) +timeout = int(os.getenv("TIMEOUT", 120)) keepalive = 5 # Optional SSL configuration @@ -33,9 +33,10 @@ if args.ssl: keyfile = args.ssl_keyfile # Logging configuration -errorlog = os.getenv('ERROR_LOG', '-') # '-' means stderr -accesslog = os.getenv('ACCESS_LOG', '-') # '-' means stderr -loglevel = os.getenv('LOG_LEVEL', 'info') +errorlog = os.getenv("ERROR_LOG", "-") # '-' means stderr +accesslog = os.getenv("ACCESS_LOG", "-") # '-' means stderr +loglevel = os.getenv("LOG_LEVEL", "info") + def on_starting(server): """ @@ -46,21 +47,25 @@ def on_starting(server): print(f"GUNICORN MASTER PROCESS: on_starting jobs for all {workers} workers") print(f"Process ID: {os.getpid()}") print("=" * 80) - + # Memory usage monitoring try: import psutil + process = psutil.Process(os.getpid()) memory_info = process.memory_info() - msg = f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB" + msg = ( + f"Memory usage after initialization: {memory_info.rss / 1024 / 1024:.2f} MB" + ) print(msg) except ImportError: print("psutil not installed, skipping memory usage reporting") - + print("=" * 80) print("Gunicorn initialization complete, forking workers...") print("=" * 80) + def on_exit(server): """ Executed when Gunicorn is shutting down. 
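
For context on how these server hooks fit together, here is a minimal illustrative Gunicorn config module, not the project's gunicorn_config.py: it shows the preload_app plus on_starting/on_exit layout the patch relies on. The bind address and worker count are placeholders, and the resource-release comment in on_exit stands for the finalize_share_data() call in the actual file.

import os

bind = "127.0.0.1:8000"          # placeholder address
workers = int(os.getenv("WORKERS", 2))
worker_class = "uvicorn.workers.UvicornWorker"
preload_app = True               # import the app once in the master, then fork


def on_starting(server):
    # Runs once in the Gunicorn master process before any worker is forked,
    # so state created here (or during preload) is inherited by all workers.
    print(f"master pid={os.getpid()}: starting {workers} workers")


def on_exit(server):
    # Runs in the master on shutdown; the real config releases the shared
    # Manager here via finalize_share_data().
    print(f"master pid={os.getpid()}: shutdown complete")
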
@@ -70,10 +75,10 @@ def on_exit(server): print("GUNICORN MASTER PROCESS: Shutting down") print(f"Process ID: {os.getpid()}") print("=" * 80) - + # Release shared resources finalize_share_data() - + print("=" * 80) print("Gunicorn shutdown complete") print("=" * 80) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 9f162290..155e22f5 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -471,12 +471,13 @@ def configure_logging(): def main(): # Check if running under Gunicorn - if 'GUNICORN_CMD_ARGS' in os.environ: + if "GUNICORN_CMD_ARGS" in os.environ: # If started with Gunicorn, return directly as Gunicorn will call get_application print("Running under Gunicorn - worker management handled by Gunicorn") return from multiprocessing import freeze_support + freeze_support() args = parse_args() @@ -487,10 +488,10 @@ def main(): configure_logging() display_splash_screen(args) - + # Create application instance directly instead of using factory function app = create_app(args) - + # Start Uvicorn in single process mode uvicorn_config = { "app": app, # Pass application instance directly instead of string path @@ -498,7 +499,7 @@ def main(): "port": args.port, "log_config": None, # Disable default config } - + if args.ssl: uvicorn_config.update( { @@ -506,7 +507,7 @@ def main(): "ssl_keyfile": args.ssl_keyfile, } ) - + print(f"Starting Uvicorn server in single-process mode on {args.host}:{args.port}") uvicorn.run(**uvicorn_config) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index f13cdfb6..0d935ebd 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -10,7 +10,11 @@ from lightrag.utils import ( logger, write_json, ) -from .shared_storage import get_namespace_data, get_storage_lock, try_initialize_namespace +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + try_initialize_namespace, +) @final @@ -20,7 +24,7 @@ class JsonKVStorage(BaseKVStorage): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() - + # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) self._data = get_namespace_data(self.namespace) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 953a19a7..43dbcf97 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -11,7 +11,12 @@ from lightrag.utils import ( ) import pipmaster as pm from lightrag.base import BaseVectorStorage -from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace +from .shared_storage import ( + get_storage_lock, + get_namespace_object, + is_multiprocess, + try_initialize_namespace, +) if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index db059393..c42a1981 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -6,7 +6,12 @@ import numpy as np from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from lightrag.base import BaseGraphStorage -from .shared_storage import get_storage_lock, get_namespace_object, is_multiprocess, try_initialize_namespace +from .shared_storage import ( + get_storage_lock, + get_namespace_object, + is_multiprocess, + try_initialize_namespace, 
+) import pipmaster as pm @@ -74,16 +79,14 @@ class NetworkXStorage(BaseGraphStorage): self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) self._storage_lock = get_storage_lock() - + # check need_init must before get_namespace_object need_init = try_initialize_namespace(self.namespace) self._graph = get_namespace_object(self.namespace) - + if need_init: if is_multiprocess: - preloaded_graph = NetworkXStorage.load_nx_graph( - self._graphml_xml_file - ) + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) self._graph.value = preloaded_graph or nx.Graph() if preloaded_graph: logger.info( @@ -92,9 +95,7 @@ class NetworkXStorage(BaseGraphStorage): else: logger.info("Created new empty graph") else: - preloaded_graph = NetworkXStorage.load_nx_graph( - self._graphml_xml_file - ) + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) self._graph = preloaded_graph or nx.Graph() if preloaded_graph: logger.info( diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 73ffb306..d8cf71c9 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -4,16 +4,17 @@ from multiprocessing.synchronize import Lock as ProcessLock from threading import Lock as ThreadLock from multiprocessing import Manager from typing import Any, Dict, Optional, Union -from lightrag.utils import logger + # Define a direct print function for critical logs that must be visible in all processes def direct_log(message, level="INFO"): """ Log a message directly to stderr to ensure visibility in all processes, including the Gunicorn master process. - """ + """ print(f"{level}: {message}", file=sys.stderr, flush=True) + LockType = Union[ProcessLock, ThreadLock] _manager = None @@ -31,39 +32,53 @@ _global_lock: Optional[LockType] = None def initialize_share_data(workers: int = 1): """ Initialize shared storage data for single or multi-process mode. - + When used with Gunicorn's preload feature, this function is called once in the master process before forking worker processes, allowing all workers to share the same initialized data. - + In single-process mode, this function is called during LightRAG object initialization. - + The function determines whether to use cross-process shared variables for data storage based on the number of workers. If workers=1, it uses thread locks and local dictionaries. If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager. - + Args: workers (int): Number of worker processes. If 1, single-process mode is used. If > 1, multi-process mode with shared memory is used. 
""" - global _manager, is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized - + global \ + _manager, \ + is_multiprocess, \ + is_multiprocess, \ + _global_lock, \ + _shared_dicts, \ + _share_objects, \ + _init_flags, \ + _initialized + # Check if already initialized if _initialized: - direct_log(f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})") + direct_log( + f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})" + ) return - + _manager = Manager() # Force multi-process mode if workers > 1 if workers > 1: is_multiprocess = True - _global_lock = _manager.Lock() + _global_lock = _manager.Lock() # Create shared dictionaries with manager _shared_dicts = _manager.dict() _share_objects = _manager.dict() - _init_flags = _manager.dict() # Use shared dictionary to store initialization flags - direct_log(f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})") + _init_flags = ( + _manager.dict() + ) # Use shared dictionary to store initialization flags + direct_log( + f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" + ) else: is_multiprocess = False _global_lock = ThreadLock() @@ -75,6 +90,7 @@ def initialize_share_data(workers: int = 1): # Mark as initialized _initialized = True + def try_initialize_namespace(namespace: str) -> bool: """ Try to initialize a namespace. Returns True if the current process gets initialization permission. @@ -83,8 +99,11 @@ def try_initialize_namespace(namespace: str) -> bool: global _init_flags, _manager if _init_flags is None: - direct_log(f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}", level="ERROR") - raise ValueError("Shared dictionaries not initialized") + direct_log( + f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}", + level="ERROR", + ) + raise ValueError("Shared dictionaries not initialized") if namespace not in _init_flags: _init_flags[namespace] = True @@ -112,7 +131,10 @@ def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" if _share_objects is None: - direct_log(f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", level="ERROR") + direct_log( + f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", + level="ERROR", + ) raise ValueError("Shared dictionaries not initialized") lock = _get_global_lock() @@ -123,14 +145,20 @@ def get_namespace_object(namespace: str) -> Any: _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None - direct_log(f"Created namespace({namespace}): type={type(_share_objects[namespace])}, pid={os.getpid()}") + direct_log( + f"Created namespace({namespace}): type={type(_share_objects[namespace])}, pid={os.getpid()}" + ) return _share_objects[namespace] + def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" if _shared_dicts is None: - direct_log(f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", level="ERROR") + direct_log( + f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", + level="ERROR", + ) raise ValueError("Shared dictionaries not initialized") lock = _get_global_lock() @@ -140,8 +168,10 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: _shared_dicts[namespace] = _manager.dict() 
else: _shared_dicts[namespace] = {} - direct_log(f"Created namespace({namespace}): type={type(_shared_dicts[namespace])}, pid={os.getpid()}") - + direct_log( + f"Created namespace({namespace}): type={type(_shared_dicts[namespace])}, pid={os.getpid()}" + ) + return _shared_dicts[namespace] @@ -153,22 +183,33 @@ def get_scan_progress() -> Dict[str, Any]: def finalize_share_data(): """ Release shared resources and clean up. - + This function should be called when the application is shutting down to properly release shared resources and avoid memory leaks. - + In multi-process mode, it shuts down the Manager and releases all shared objects. In single-process mode, it simply resets the global variables. """ - global _manager, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized - + global \ + _manager, \ + is_multiprocess, \ + _global_lock, \ + _shared_dicts, \ + _share_objects, \ + _init_flags, \ + _initialized + # Check if already initialized if not _initialized: - direct_log(f"Process {os.getpid()} storage data not initialized, nothing to finalize") + direct_log( + f"Process {os.getpid()} storage data not initialized, nothing to finalize" + ) return - - direct_log(f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})") - + + direct_log( + f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})" + ) + # In multi-process mode, shut down the Manager if is_multiprocess and _manager is not None: try: @@ -179,13 +220,15 @@ def finalize_share_data(): _share_objects.clear() if _init_flags is not None: _init_flags.clear() - + # Shut down the Manager _manager.shutdown() direct_log(f"Process {os.getpid()} Manager shutdown complete") except Exception as e: - direct_log(f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR") - + direct_log( + f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR" + ) + # Reset global variables _manager = None _initialized = None @@ -194,5 +237,5 @@ def finalize_share_data(): _share_objects = None _init_flags = None _global_lock = None - + direct_log(f"Process {os.getpid()} storage data finalization complete") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 924fbae3..ae250bac 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -271,12 +271,17 @@ class LightRAG: set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") - from lightrag.kg.shared_storage import initialize_share_data, try_initialize_namespace, get_namespace_data + from lightrag.kg.shared_storage import ( + initialize_share_data, + try_initialize_namespace, + get_namespace_data, + ) + initialize_share_data() - need_init = try_initialize_namespace("scan_progress") + need_init = try_initialize_namespace("scan_progress") scan_progress = get_namespace_data("scan_progress") - logger.info(f"scan_progress type after init: {type(scan_progress)}") + logger.info(f"scan_progress type after init: {type(scan_progress)}") scan_progress.update( { "is_scanning": False, @@ -286,9 +291,6 @@ class LightRAG: "progress": 0, } ) - scan_progress = get_namespace_data("scan_progress") - logger.info(f"scan_progress type after update: {type(scan_progress)}") - logger.info(f"Scan_progres value after update: {scan_progress}") if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 44a49e93..705cb88f 100755 --- 
a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -2,127 +2,149 @@ """ Start LightRAG server with Gunicorn """ + import os import sys import json import signal import argparse -import subprocess from lightrag.api.utils_api import parse_args, display_splash_screen from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data + # Signal handler for graceful shutdown def signal_handler(sig, frame): - print("\n\n" + "="*80) + print("\n\n" + "=" * 80) print("RECEIVED TERMINATION SIGNAL") print(f"Process ID: {os.getpid()}") - print("="*80 + "\n") - + print("=" * 80 + "\n") + # Release shared resources finalize_share_data() - + # Exit with success status sys.exit(0) + def main(): # Register signal handlers for graceful shutdown signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command + signal.signal(signal.SIGTERM, signal_handler) # kill command # Create a parser to handle Gunicorn-specific parameters - parser = argparse.ArgumentParser( - description="Start LightRAG server with Gunicorn" - ) + parser = argparse.ArgumentParser(description="Start LightRAG server with Gunicorn") parser.add_argument( "--workers", type=int, - help="Number of worker processes (overrides the default or config.ini setting)" + help="Number of worker processes (overrides the default or config.ini setting)", ) parser.add_argument( - "--timeout", - type=int, - help="Worker timeout in seconds (default: 120)" + "--timeout", type=int, help="Worker timeout in seconds (default: 120)" ) parser.add_argument( "--log-level", choices=["debug", "info", "warning", "error", "critical"], - help="Gunicorn log level" + help="Gunicorn log level", ) - + # Parse Gunicorn-specific arguments gunicorn_args, remaining_args = parser.parse_known_args() - + # Pass remaining arguments to LightRAG's parse_args sys.argv = [sys.argv[0]] + remaining_args args = parse_args() - + # If workers specified, override args value if gunicorn_args.workers: args.workers = gunicorn_args.workers os.environ["WORKERS"] = str(gunicorn_args.workers) - + # If timeout specified, set environment variable if gunicorn_args.timeout: os.environ["TIMEOUT"] = str(gunicorn_args.timeout) - + # If log-level specified, set environment variable if gunicorn_args.log_level: os.environ["LOG_LEVEL"] = gunicorn_args.log_level - + # Save all LightRAG args to environment variable for worker processes # This is the key step for passing arguments to lightrag_server.py os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) - + # Display startup information display_splash_screen(args) - + print("๐Ÿš€ Starting LightRAG with Gunicorn") print(f"๐Ÿ”„ Worker management: Gunicorn (workers={args.workers})") print("๐Ÿ” Preloading app: Enabled") print("๐Ÿ“ Note: Using Gunicorn's preload feature for shared data initialization") - print("\n\n" + "="*80) + print("\n\n" + "=" * 80) print("MAIN PROCESS INITIALIZATION") print(f"Process ID: {os.getpid()}") print(f"Workers setting: {args.workers}") - print("="*80 + "\n") - + print("=" * 80 + "\n") + # Start application with Gunicorn using direct Python API # Ensure WORKERS environment variable is set before importing gunicorn_config if args.workers > 1: os.environ["WORKERS"] = str(args.workers) - + # Import Gunicorn's StandaloneApplication from gunicorn.app.base import BaseApplication - + # Define a custom application class that loads our config class GunicornApp(BaseApplication): def __init__(self, app, options=None): self.options = options or {} self.application = app 
super().__init__() - + def load_config(self): # Define valid Gunicorn configuration options valid_options = { - 'bind', 'workers', 'worker_class', 'timeout', 'keepalive', - 'preload_app', 'errorlog', 'accesslog', 'loglevel', - 'certfile', 'keyfile', 'limit_request_line', 'limit_request_fields', - 'limit_request_field_size', 'graceful_timeout', 'max_requests', - 'max_requests_jitter' + "bind", + "workers", + "worker_class", + "timeout", + "keepalive", + "preload_app", + "errorlog", + "accesslog", + "loglevel", + "certfile", + "keyfile", + "limit_request_line", + "limit_request_fields", + "limit_request_field_size", + "graceful_timeout", + "max_requests", + "max_requests_jitter", } - + # Special hooks that need to be set separately special_hooks = { - 'on_starting', 'on_reload', 'on_exit', 'pre_fork', 'post_fork', - 'pre_exec', 'pre_request', 'post_request', 'worker_init', - 'worker_exit', 'nworkers_changed', 'child_exit' + "on_starting", + "on_reload", + "on_exit", + "pre_fork", + "post_fork", + "pre_exec", + "pre_request", + "post_request", + "worker_init", + "worker_exit", + "nworkers_changed", + "child_exit", } - + # Import the gunicorn_config module directly import importlib.util - spec = importlib.util.spec_from_file_location("gunicorn_config", "gunicorn_config.py") + + spec = importlib.util.spec_from_file_location( + "gunicorn_config", "gunicorn_config.py" + ) self.config_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(self.config_module) - + # Set configuration options for key in dir(self.config_module): if key in valid_options: @@ -135,7 +157,7 @@ def main(): value = getattr(self.config_module, key) if callable(value): self.cfg.set(key, value) - + # Override with command line arguments if provided if gunicorn_args.workers: self.cfg.set("workers", gunicorn_args.workers) @@ -143,18 +165,18 @@ def main(): self.cfg.set("timeout", gunicorn_args.timeout) if gunicorn_args.log_level: self.cfg.set("loglevel", gunicorn_args.log_level) - + def load(self): # Import the application from lightrag.api.lightrag_server import get_application + return get_application() - + # Create the application app = GunicornApp("") - + # Directly call initialize_share_data with the correct workers value - from lightrag.kg.shared_storage import initialize_share_data - + # Force workers to be an integer and greater than 1 for multi-process mode workers_count = int(args.workers) if workers_count > 1: @@ -163,10 +185,11 @@ def main(): initialize_share_data(workers_count) else: initialize_share_data(1) - + # Run the application print("\nStarting Gunicorn with direct Python API...") app.run() + if __name__ == "__main__": main() From 27500191b40f0cb9b31f3ebb721a7545936c89ff Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 19:08:36 +0800 Subject: [PATCH 28/77] Standarize scan progress namespace initialization --- lightrag/lightrag.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ae250bac..0011fb6f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -278,19 +278,18 @@ class LightRAG: ) initialize_share_data() - need_init = try_initialize_namespace("scan_progress") scan_progress = get_namespace_data("scan_progress") - logger.info(f"scan_progress type after init: {type(scan_progress)}") - scan_progress.update( - { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - } - ) + if need_init: + scan_progress.update( + { + "is_scanning": 
False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, + } + ) if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") From 05cf029bcc71a1f879b66efa4494c0224d07c1b9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 20:16:53 +0800 Subject: [PATCH 29/77] fix: convert multiprocessing managed dict to normal dict before JSON dump --- lightrag/kg/json_doc_status_impl.py | 4 ++-- lightrag/kg/json_kv_impl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 2a85c68a..b71cf618 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -83,9 +83,9 @@ class JsonDocStatusStorage(DocStatusStorage): return result async def index_done_callback(self) -> None: - # ๆ–‡ไปถๅ†™ๅ…ฅ้œ€่ฆๅŠ ้”๏ผŒ้˜ฒๆญขๅคšไธช่ฟ›็จ‹ๅŒๆ—ถๅ†™ๅ…ฅๅฏผ่‡ดๆ–‡ไปถๆŸๅ with self._storage_lock: - write_json(self._data, self._file_name) + data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data + write_json(data_dict, self._file_name) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 0d935ebd..c5bff177 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -35,9 +35,9 @@ class JsonKVStorage(BaseKVStorage): logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data") async def index_done_callback(self) -> None: - # ๆ–‡ไปถๅ†™ๅ…ฅ้œ€่ฆๅŠ ้”๏ผŒ้˜ฒๆญขๅคšไธช่ฟ›็จ‹ๅŒๆ—ถๅ†™ๅ…ฅๅฏผ่‡ดๆ–‡ไปถๆŸๅ with self._storage_lock: - write_json(self._data, self._file_name) + data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data + write_json(data_dict, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: with self._storage_lock: From 05d03638ecfad35335131c4a335b32170c0f0b0c Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 20:17:28 +0800 Subject: [PATCH 30/77] Clean up logging output and remove redundant log messages --- gunicorn_config.py | 2 -- lightrag/kg/networkx_impl.py | 6 ++---- lightrag/kg/shared_storage.py | 8 ++++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 7239acd9..8500dad6 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -61,9 +61,7 @@ def on_starting(server): except ImportError: print("psutil not installed, skipping memory usage reporting") - print("=" * 80) print("Gunicorn initialization complete, forking workers...") - print("=" * 80) def on_exit(server): diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index c42a1981..d42db33a 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -92,8 +92,6 @@ class NetworkXStorage(BaseGraphStorage): logger.info( f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) - else: - logger.info("Created new empty graph") else: preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) self._graph = preloaded_graph or nx.Graph() @@ -101,8 +99,8 @@ class NetworkXStorage(BaseGraphStorage): logger.info( f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" ) - else: - logger.info("Created new empty graph") + + 
logger.info("Created new empty graph") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index d8cf71c9..c57771ba 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -107,9 +107,9 @@ def try_initialize_namespace(namespace: str) -> bool: if namespace not in _init_flags: _init_flags[namespace] = True - direct_log(f"Process {os.getpid()} ready to initialize namespace {namespace}") + direct_log(f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]") return True - direct_log(f"Process {os.getpid()} namespace {namespace} already to initialized") + direct_log(f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]") return False @@ -146,7 +146,7 @@ def get_namespace_object(namespace: str) -> Any: else: _share_objects[namespace] = None direct_log( - f"Created namespace({namespace}): type={type(_share_objects[namespace])}, pid={os.getpid()}" + f"Created namespace: {namespace}(type={type(_share_objects[namespace])})" ) return _share_objects[namespace] @@ -169,7 +169,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: else: _shared_dicts[namespace] = {} direct_log( - f"Created namespace({namespace}): type={type(_shared_dicts[namespace])}, pid={os.getpid()}" + f"Created namespace: {{namespace}}({type(_shared_dicts[namespace])}) " ) return _shared_dicts[namespace] From cf2f6b726ca13405eade6e54df6df04f6fd0f47b Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 20:26:12 +0800 Subject: [PATCH 31/77] Add newline after Gunicorn initialization message for better readability --- gunicorn_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 8500dad6..e89b8e12 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -61,7 +61,7 @@ def on_starting(server): except ImportError: print("psutil not installed, skipping memory usage reporting") - print("Gunicorn initialization complete, forking workers...") + print("Gunicorn initialization complete, forking workers...\n") def on_exit(server): From db2a902fcb0c99ac8aae2d0f8688efafe197adb8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 00:34:33 +0800 Subject: [PATCH 32/77] Rename get_scan_lock to get_storage_lock --- lightrag/api/routers/document_routes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 5d90fb83..2a6459fb 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -21,7 +21,7 @@ from lightrag.base import DocProcessingStatus, DocStatus from ..utils_api import get_api_key_dependency from lightrag.kg.shared_storage import ( get_namespace_data, - get_scan_lock, + get_storage_lock, ) @@ -377,7 +377,7 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" scan_progress = get_namespace_data("scan_progress") - scan_lock = get_scan_lock() + scan_lock = get_storage_lock() with scan_lock: if scan_progress.get("is_scanning", False): ASCIIColors.info("Skip document scanning(another scanning is active)") From 291e0c1b147bde415f10f0c17556032dc571aec9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 01:14:25 +0800 Subject: [PATCH 33/77] revert vector and graph use local data(single 
process) --- lightrag/kg/faiss_impl.py | 278 +++++++++++++---------------- lightrag/kg/nano_vector_db_impl.py | 123 +++++-------- lightrag/kg/networkx_impl.py | 268 ++++++++++++--------------- lightrag/kg/shared_storage.py | 61 +------ 4 files changed, 287 insertions(+), 443 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index b6b998e4..a3520653 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -10,19 +10,12 @@ import pipmaster as pm from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseVectorStorage -from .shared_storage import ( - get_namespace_data, - get_storage_lock, - get_namespace_object, - is_multiprocess, - try_initialize_namespace, -) if not pm.is_installed("faiss"): pm.install("faiss") import faiss # type: ignore - +from threading import Lock as ThreadLock @final @dataclass @@ -51,35 +44,29 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim - self._storage_lock = get_storage_lock() + self._storage_lock = ThreadLock() - # check need_init must before get_namespace_object/get_namespace_data - need_init = try_initialize_namespace("faiss_indices") - self._index = get_namespace_object("faiss_indices") - self._id_to_meta = get_namespace_data("faiss_meta") + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). + # If you have a large number of vectors, you might want IVF or other indexes. + # For demonstration, we use a simple IndexFlatIP. + self._index = faiss.IndexFlatIP(self._dim) + + # Keep a local store for metadata, IDs, etc. + # Maps โ†’ metadata (including your original ID). + self._id_to_meta = {} + + # Attempt to load an existing index + metadata from disk + with self._storage_lock: + self._load_faiss_index() - if need_init: - if is_multiprocess: - # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). - # If you have a large number of vectors, you might want IVF or other indexes. - # For demonstration, we use a simple IndexFlatIP. - self._index.value = faiss.IndexFlatIP(self._dim) - # Keep a local store for metadata, IDs, etc. - # Maps โ†’ metadata (including your original ID). - self._id_to_meta.update({}) - # Attempt to load an existing index + metadata from disk - self._load_faiss_index() - else: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta.update({}) - self._load_faiss_index() def _get_index(self): - """ - Helper method to get the correct index object based on multiprocess mode. - Returns the actual index object that can be used for operations. - """ - return self._index.value if is_multiprocess else self._index + """Check if the shtorage should be reloaded""" + return self._index + + async def index_done_callback(self) -> None: + with self._storage_lock: + self._save_faiss_index() async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ @@ -134,34 +121,33 @@ class FaissVectorDBStorage(BaseVectorStorage): # Normalize embeddings for cosine similarity (in-place) faiss.normalize_L2(embeddings) - with self._storage_lock: - # Upsert logic: - # 1. Identify which vectors to remove if they exist - # 2. Remove them - # 3. 
Add the new vectors - existing_ids_to_remove = [] - for meta, emb in zip(list_data, embeddings): - faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) - if faiss_internal_id is not None: - existing_ids_to_remove.append(faiss_internal_id) + # Upsert logic: + # 1. Identify which vectors to remove if they exist + # 2. Remove them + # 3. Add the new vectors + existing_ids_to_remove = [] + for meta, emb in zip(list_data, embeddings): + faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"]) + if faiss_internal_id is not None: + existing_ids_to_remove.append(faiss_internal_id) - if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + if existing_ids_to_remove: + self._remove_faiss_ids(existing_ids_to_remove) - # Step 2: Add new vectors - index = self._get_index() - start_idx = index.ntotal - index.add(embeddings) + # Step 2: Add new vectors + index = self._get_index() + start_idx = index.ntotal + index.add(embeddings) - # Step 3: Store metadata + vector for each new ID - for i, meta in enumerate(list_data): - fid = start_idx + i - # Store the raw vector so we can rebuild if something is removed - meta["__vector__"] = embeddings[i].tolist() - self._id_to_meta.update({fid: meta}) + # Step 3: Store metadata + vector for each new ID + for i, meta in enumerate(list_data): + fid = start_idx + i + # Store the raw vector so we can rebuild if something is removed + meta["__vector__"] = embeddings[i].tolist() + self._id_to_meta.update({fid: meta}) - logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") - return [m["__id__"] for m in list_data] + logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") + return [m["__id__"] for m in list_data] async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ @@ -177,57 +163,54 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - with self._storage_lock: - distances, indices = self._get_index().search(embedding, top_k) + distances, indices = self._get_index().search(embedding, top_k) - distances = distances[0] - indices = indices[0] + distances = distances[0] + indices = indices[0] - results = [] - for dist, idx in zip(distances, indices): - if idx == -1: - # Faiss returns -1 if no neighbor - continue + results = [] + for dist, idx in zip(distances, indices): + if idx == -1: + # Faiss returns -1 if no neighbor + continue - # Cosine similarity threshold - if dist < self.cosine_better_than_threshold: - continue + # Cosine similarity threshold + if dist < self.cosine_better_than_threshold: + continue - meta = self._id_to_meta.get(idx, {}) - results.append( - { - **meta, - "id": meta.get("__id__"), - "distance": float(dist), - "created_at": meta.get("__created_at__"), - } - ) + meta = self._id_to_meta.get(idx, {}) + results.append( + { + **meta, + "id": meta.get("__id__"), + "distance": float(dist), + "created_at": meta.get("__created_at__"), + } + ) - return results + return results @property def client_storage(self): # Return whatever structure LightRAG might need for debugging - with self._storage_lock: - return {"data": list(self._id_to_meta.values())} + return {"data": list(self._id_to_meta.values())} async def delete(self, ids: list[str]): """ Delete vectors for the provided custom IDs. 
""" logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") - with self._storage_lock: - to_remove = [] - for cid in ids: - fid = self._find_faiss_id_by_custom_id(cid) - if fid is not None: - to_remove.append(fid) + to_remove = [] + for cid in ids: + fid = self._find_faiss_id_by_custom_id(cid) + if fid is not None: + to_remove.append(fid) - if to_remove: - self._remove_faiss_ids(to_remove) - logger.debug( - f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" - ) + if to_remove: + self._remove_faiss_ids(to_remove) + logger.debug( + f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" + ) async def delete_entity(self, entity_name: str) -> None: entity_id = compute_mdhash_id(entity_name, prefix="ent-") @@ -239,23 +222,18 @@ class FaissVectorDBStorage(BaseVectorStorage): Delete relations for a given entity by scanning metadata. """ logger.debug(f"Searching relations for entity {entity_name}") - with self._storage_lock: - relations = [] - for fid, meta in self._id_to_meta.items(): - if ( - meta.get("src_id") == entity_name - or meta.get("tgt_id") == entity_name - ): - relations.append(fid) + relations = [] + for fid, meta in self._id_to_meta.items(): + if ( + meta.get("src_id") == entity_name + or meta.get("tgt_id") == entity_name + ): + relations.append(fid) - logger.debug(f"Found {len(relations)} relations for {entity_name}") - if relations: - self._remove_faiss_ids(relations) - logger.debug(f"Deleted {len(relations)} relations for {entity_name}") - - async def index_done_callback(self) -> None: - with self._storage_lock: - self._save_faiss_index() + logger.debug(f"Found {len(relations)} relations for {entity_name}") + if relations: + self._remove_faiss_ids(relations) + logger.debug(f"Deleted {len(relations)} relations for {entity_name}") # -------------------------------------------------------------------------------- # Internal helper methods @@ -265,11 +243,10 @@ class FaissVectorDBStorage(BaseVectorStorage): """ Return the Faiss internal ID for a given custom ID, or None if not found. """ - with self._storage_lock: - for fid, meta in self._id_to_meta.items(): - if meta.get("__id__") == custom_id: - return fid - return None + for fid, meta in self._id_to_meta.items(): + if meta.get("__id__") == custom_id: + return fid + return None def _remove_faiss_ids(self, fid_list): """ @@ -277,48 +254,42 @@ class FaissVectorDBStorage(BaseVectorStorage): Because IndexFlatIP doesn't support 'removals', we rebuild the index excluding those vectors. 
""" + keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] + + # Rebuild the index + vectors_to_keep = [] + new_id_to_meta = {} + for new_fid, old_fid in enumerate(keep_fids): + vec_meta = self._id_to_meta[old_fid] + vectors_to_keep.append(vec_meta["__vector__"]) # stored as list + new_id_to_meta[new_fid] = vec_meta + with self._storage_lock: - keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] - - # Rebuild the index - vectors_to_keep = [] - new_id_to_meta = {} - for new_fid, old_fid in enumerate(keep_fids): - vec_meta = self._id_to_meta[old_fid] - vectors_to_keep.append(vec_meta["__vector__"]) # stored as list - new_id_to_meta[new_fid] = vec_meta - - # Re-init index - new_index = faiss.IndexFlatIP(self._dim) + # Re-init index + self._index = faiss.IndexFlatIP(self._dim) if vectors_to_keep: arr = np.array(vectors_to_keep, dtype=np.float32) - new_index.add(arr) - if is_multiprocess: - self._index.value = new_index - else: - self._index = new_index + self._index.add(arr) + + self._id_to_meta = new_id_to_meta - self._id_to_meta.update(new_id_to_meta) def _save_faiss_index(self): """ Save the current Faiss index + metadata to disk so it can persist across runs. """ - with self._storage_lock: - faiss.write_index( - self._get_index(), - self._faiss_index_file, - ) + faiss.write_index(self._index, self._faiss_index_file) - # Save metadata dict to JSON. Convert all keys to strings for JSON storage. - # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } - # We'll keep the int -> dict, but JSON requires string keys. - serializable_dict = {} - for fid, meta in self._id_to_meta.items(): - serializable_dict[str(fid)] = meta + # Save metadata dict to JSON. Convert all keys to strings for JSON storage. + # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } + # We'll keep the int -> dict, but JSON requires string keys. 
+ serializable_dict = {} + for fid, meta in self._id_to_meta.items(): + serializable_dict[str(fid)] = meta + + with open(self._meta_file, "w", encoding="utf-8") as f: + json.dump(serializable_dict, f) - with open(self._meta_file, "w", encoding="utf-8") as f: - json.dump(serializable_dict, f) def _load_faiss_index(self): """ @@ -331,31 +302,22 @@ class FaissVectorDBStorage(BaseVectorStorage): try: # Load the Faiss index - loaded_index = faiss.read_index(self._faiss_index_file) - if is_multiprocess: - self._index.value = loaded_index - else: - self._index = loaded_index - + self._index = faiss.read_index(self._faiss_index_file) # Load metadata with open(self._meta_file, "r", encoding="utf-8") as f: stored_dict = json.load(f) # Convert string keys back to int - self._id_to_meta.update({}) + self._id_to_meta = {} for fid_str, meta in stored_dict.items(): fid = int(fid_str) self._id_to_meta[fid] = meta logger.info( - f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}" + f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}" ) except Exception as e: logger.error(f"Failed to load Faiss index or metadata: {e}") logger.warning("Starting with an empty Faiss index.") - new_index = faiss.IndexFlatIP(self._dim) - if is_multiprocess: - self._index.value = new_index - else: - self._index = new_index - self._id_to_meta.update({}) + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 43dbcf97..b8fe573d 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -11,25 +11,19 @@ from lightrag.utils import ( ) import pipmaster as pm from lightrag.base import BaseVectorStorage -from .shared_storage import ( - get_storage_lock, - get_namespace_object, - is_multiprocess, - try_initialize_namespace, -) if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB - +from threading import Lock as ThreadLock @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Initialize lock only for file operations - self._storage_lock = get_storage_lock() + self._storage_lock = ThreadLock() # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -45,32 +39,14 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self._max_batch_size = self.global_config["embedding_batch_num"] - # check need_init must before get_namespace_object - need_init = try_initialize_namespace(self.namespace) - self._client = get_namespace_object(self.namespace) - - if need_init: - if is_multiprocess: - self._client.value = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name, - ) - logger.info( - f"Initialized vector DB client for namespace {self.namespace}" - ) - else: - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name, - ) - logger.info( - f"Initialized vector DB client for namespace {self.namespace}" - ) + with self._storage_lock: + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) def _get_client(self): - """Get the appropriate client instance based on multiprocess mode""" - if is_multiprocess: - return self._client.value + """Check if the shtorage should be reloaded""" return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> 
None: @@ -101,8 +77,7 @@ class NanoVectorDBStorage(BaseVectorStorage): if len(embeddings) == len(list_data): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - with self._storage_lock: - results = self._get_client().upsert(datas=list_data) + results = self._get_client().upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. @@ -115,21 +90,20 @@ class NanoVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) embedding = embedding[0] - with self._storage_lock: - results = self._get_client().query( - query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - { - **dp, - "id": dp["__id__"], - "distance": dp["__metrics__"], - "created_at": dp.get("__created_at__"), - } - for dp in results - ] + results = self._get_client().query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] return results @property @@ -143,8 +117,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - with self._storage_lock: - self._get_client().delete(ids) + self._get_client().delete(ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) @@ -158,37 +131,35 @@ class NanoVectorDBStorage(BaseVectorStorage): f"Attempting to delete entity {entity_name} with ID {entity_id}" ) - with self._storage_lock: - # Check if the entity exists - if self._get_client().get([entity_id]): - self._get_client().delete([entity_id]) - logger.debug(f"Successfully deleted entity {entity_name}") - else: - logger.debug(f"Entity {entity_name} not found in storage") + # Check if the entity exists + if self._get_client().get([entity_id]): + self._get_client().delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") except Exception as e: logger.error(f"Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: try: - with self._storage_lock: - storage = getattr(self._get_client(), "_NanoVectorDB__storage") - relations = [ - dp - for dp in storage["data"] - if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name - ] - logger.debug( - f"Found {len(relations)} relations for entity {entity_name}" - ) - ids_to_delete = [relation["__id__"] for relation in relations] + storage = getattr(self._get_client(), "_NanoVectorDB__storage") + relations = [ + dp + for dp in storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug( + f"Found {len(relations)} relations for entity {entity_name}" + ) + ids_to_delete = [relation["__id__"] for relation in relations] - if ids_to_delete: - self._get_client().delete(ids_to_delete) - logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" - ) - else: - logger.debug(f"No relations found for entity {entity_name}") + if ids_to_delete: + self._get_client().delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 
d42db33a..1f14d5b0 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -6,12 +6,6 @@ import numpy as np from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.utils import logger from lightrag.base import BaseGraphStorage -from .shared_storage import ( - get_storage_lock, - get_namespace_object, - is_multiprocess, - try_initialize_namespace, -) import pipmaster as pm @@ -23,7 +17,7 @@ if not pm.is_installed("graspologic"): import networkx as nx from graspologic import embed - +from threading import Lock as ThreadLock @final @dataclass @@ -78,38 +72,23 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) - self._storage_lock = get_storage_lock() - - # check need_init must before get_namespace_object - need_init = try_initialize_namespace(self.namespace) - self._graph = get_namespace_object(self.namespace) - - if need_init: - if is_multiprocess: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - self._graph.value = preloaded_graph or nx.Graph() - if preloaded_graph: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) - else: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - self._graph = preloaded_graph or nx.Graph() - if preloaded_graph: - logger.info( - f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" - ) + self._storage_lock = ThreadLock() + with self._storage_lock: + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + if preloaded_graph is not None: + logger.info( + f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" + ) + else: logger.info("Created new empty graph") - + self._graph = preloaded_graph or nx.Graph() self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } def _get_graph(self): - """Get the appropriate graph instance based on multiprocess mode""" - if is_multiprocess: - return self._graph.value + """Check if the shtorage should be reloaded""" return self._graph async def index_done_callback(self) -> None: @@ -117,54 +96,44 @@ class NetworkXStorage(BaseGraphStorage): NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: - with self._storage_lock: - return self._get_graph().has_node(node_id) + return self._get_graph().has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - with self._storage_lock: - return self._get_graph().has_edge(source_node_id, target_node_id) + return self._get_graph().has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: - with self._storage_lock: - return self._get_graph().nodes.get(node_id) + return self._get_graph().nodes.get(node_id) async def node_degree(self, node_id: str) -> int: - with self._storage_lock: - return self._get_graph().degree(node_id) + return self._get_graph().degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: - with self._storage_lock: - return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id) + return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id) async def get_edge( self, source_node_id: 
str, target_node_id: str ) -> dict[str, str] | None: - with self._storage_lock: - return self._get_graph().edges.get((source_node_id, target_node_id)) + return self._get_graph().edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - with self._storage_lock: - if self._get_graph().has_node(source_node_id): - return list(self._get_graph().edges(source_node_id)) - return None + if self._get_graph().has_node(source_node_id): + return list(self._get_graph().edges(source_node_id)) + return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - with self._storage_lock: - self._get_graph().add_node(node_id, **node_data) + self._get_graph().add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - with self._storage_lock: - self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) + self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: - with self._storage_lock: - if self._get_graph().has_node(node_id): - self._get_graph().remove_node(node_id) - logger.debug(f"Node {node_id} deleted from the graph.") - else: - logger.warning(f"Node {node_id} not found in the graph for deletion.") + if self._get_graph().has_node(node_id): + self._get_graph().remove_node(node_id) + logger.debug(f"Node {node_id} deleted from the graph.") + else: + logger.warning(f"Node {node_id} not found in the graph for deletion.") async def embed_nodes( self, algorithm: str @@ -175,13 +144,12 @@ class NetworkXStorage(BaseGraphStorage): # TODO: NOT USED async def _node2vec_embed(self): - with self._storage_lock: - graph = self._get_graph() - embeddings, nodes = embed.node2vec_embed( - graph, - **self.global_config["node2vec_params"], - ) - nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] + graph = self._get_graph() + embeddings, nodes = embed.node2vec_embed( + graph, + **self.global_config["node2vec_params"], + ) + nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids def remove_nodes(self, nodes: list[str]): @@ -190,11 +158,10 @@ class NetworkXStorage(BaseGraphStorage): Args: nodes: List of node IDs to be deleted """ - with self._storage_lock: - graph = self._get_graph() - for node in nodes: - if graph.has_node(node): - graph.remove_node(node) + graph = self._get_graph() + for node in nodes: + if graph.has_node(node): + graph.remove_node(node) def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges @@ -202,11 +169,10 @@ class NetworkXStorage(BaseGraphStorage): Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ - with self._storage_lock: - graph = self._get_graph() - for source, target in edges: - if graph.has_edge(source, target): - graph.remove_edge(source, target) + graph = self._get_graph() + for source, target in edges: + if graph.has_edge(source, target): + graph.remove_edge(source, target) async def get_all_labels(self) -> list[str]: """ @@ -214,10 +180,9 @@ class NetworkXStorage(BaseGraphStorage): Returns: [label1, label2, ...] 
# Alphabetically sorted label list """ - with self._storage_lock: - labels = set() - for node in self._get_graph().nodes(): - labels.add(str(node)) # Add node id as a label + labels = set() + for node in self._get_graph().nodes(): + labels.add(str(node)) # Add node id as a label # Return sorted list return sorted(list(labels)) @@ -239,88 +204,87 @@ class NetworkXStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - with self._storage_lock: - graph = self._get_graph() + graph = self._get_graph() - # Handle special case for "*" label - if node_label == "*": - # For "*", return the entire graph including all nodes and edges - subgraph = ( - graph.copy() - ) # Create a copy to avoid modifying the original graph - else: - # Find nodes with matching node id (partial match) - nodes_to_explore = [] - for n, attr in graph.nodes(data=True): - if node_label in str(n): # Use partial matching - nodes_to_explore.append(n) + # Handle special case for "*" label + if node_label == "*": + # For "*", return the entire graph including all nodes and edges + subgraph = ( + graph.copy() + ) # Create a copy to avoid modifying the original graph + else: + # Find nodes with matching node id (partial match) + nodes_to_explore = [] + for n, attr in graph.nodes(data=True): + if node_label in str(n): # Use partial matching + nodes_to_explore.append(n) - if not nodes_to_explore: - logger.warning(f"No nodes found with label {node_label}") - return result + if not nodes_to_explore: + logger.warning(f"No nodes found with label {node_label}") + return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph + subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) - # Check if number of nodes exceeds max_graph_nodes - max_graph_nodes = 500 - if len(subgraph.nodes()) > max_graph_nodes: - origin_nodes = len(subgraph.nodes()) - node_degrees = dict(subgraph.degree()) - top_nodes = sorted( - node_degrees.items(), key=lambda x: x[1], reverse=True - )[:max_graph_nodes] - top_node_ids = [node[0] for node in top_nodes] - # Create new subgraph with only top nodes - subgraph = subgraph.subgraph(top_node_ids) - logger.info( - f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" + # Check if number of nodes exceeds max_graph_nodes + max_graph_nodes = 500 + if len(subgraph.nodes()) > max_graph_nodes: + origin_nodes = len(subgraph.nodes()) + node_degrees = dict(subgraph.degree()) + top_nodes = sorted( + node_degrees.items(), key=lambda x: x[1], reverse=True + )[:max_graph_nodes] + top_node_ids = [node[0] for node in top_nodes] + # Create new subgraph with only top nodes + subgraph = subgraph.subgraph(top_node_ids) + logger.info( + f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})" + ) + + # Add nodes to result + for node in subgraph.nodes(): + if str(node) in seen_nodes: + continue + + node_data = dict(subgraph.nodes[node]) + # Get entity_type as labels + labels = [] + if "entity_type" in node_data: + if isinstance(node_data["entity_type"], list): + labels.extend(node_data["entity_type"]) + else: + labels.append(node_data["entity_type"]) + + # Create node with properties + node_properties = {k: v for k, v in node_data.items()} + + result.nodes.append( + KnowledgeGraphNode( + id=str(node), labels=[str(node)], properties=node_properties ) + ) + seen_nodes.add(str(node)) - # Add nodes to result - for node in subgraph.nodes(): - if str(node) in 
seen_nodes: - continue + # Add edges to result + for edge in subgraph.edges(): + source, target = edge + edge_id = f"{source}-{target}" + if edge_id in seen_edges: + continue - node_data = dict(subgraph.nodes[node]) - # Get entity_type as labels - labels = [] - if "entity_type" in node_data: - if isinstance(node_data["entity_type"], list): - labels.extend(node_data["entity_type"]) - else: - labels.append(node_data["entity_type"]) + edge_data = dict(subgraph.edges[edge]) - # Create node with properties - node_properties = {k: v for k, v in node_data.items()} - - result.nodes.append( - KnowledgeGraphNode( - id=str(node), labels=[str(node)], properties=node_properties - ) + # Create edge with complete information + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(source), + target=str(target), + properties=edge_data, ) - seen_nodes.add(str(node)) - - # Add edges to result - for edge in subgraph.edges(): - source, target = edge - edge_id = f"{source}-{target}" - if edge_id in seen_edges: - continue - - edge_data = dict(subgraph.edges[edge]) - - # Create edge with complete information - result.edges.append( - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=str(source), - target=str(target), - properties=edge_data, - ) - ) - seen_edges.add(edge_id) + ) + seen_edges.add(edge_id) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index c57771ba..681ef064 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -20,15 +20,12 @@ LockType = Union[ProcessLock, ThreadLock] _manager = None _initialized = None is_multiprocess = None +_global_lock: Optional[LockType] = None # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None -_share_objects: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized -_global_lock: Optional[LockType] = None - - def initialize_share_data(workers: int = 1): """ Initialize shared storage data for single or multi-process mode. 
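For reference, a minimal usage sketch of the storage API this refactor leaves in place: one global lock from get_storage_lock() plus a per-namespace dictionary from get_namespace_data(). The helper names and signatures are the ones shown in this file; the "demo_cache" namespace and the stored keys are illustrative only, not part of the patch.

    from lightrag.kg.shared_storage import (
        initialize_share_data,
        get_storage_lock,
        get_namespace_data,
    )

    # Single-process mode: plain dicts guarded by a threading.Lock
    initialize_share_data(workers=1)

    cache = get_namespace_data("demo_cache")   # shared dict for this namespace
    lock = get_storage_lock()                  # one global lock guards every namespace

    with lock:                                 # mutate shared state only while holding the lock
        cache["last_scan"] = "2025-02-28"

    with lock:
        print(cache.get("last_scan"))
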
@@ -53,7 +50,6 @@ def initialize_share_data(workers: int = 1): is_multiprocess, \ _global_lock, \ _shared_dicts, \ - _share_objects, \ _init_flags, \ _initialized @@ -72,7 +68,6 @@ def initialize_share_data(workers: int = 1): _global_lock = _manager.Lock() # Create shared dictionaries with manager _shared_dicts = _manager.dict() - _share_objects = _manager.dict() _init_flags = ( _manager.dict() ) # Use shared dictionary to store initialization flags @@ -83,7 +78,6 @@ def initialize_share_data(workers: int = 1): is_multiprocess = False _global_lock = ThreadLock() _shared_dicts = {} - _share_objects = {} _init_flags = {} direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") @@ -99,11 +93,7 @@ def try_initialize_namespace(namespace: str) -> bool: global _init_flags, _manager if _init_flags is None: - direct_log( - f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}", - level="ERROR", - ) - raise ValueError("Shared dictionaries not initialized") + raise ValueError("Try to create nanmespace before Shared-Data is initialized") if namespace not in _init_flags: _init_flags[namespace] = True @@ -113,43 +103,9 @@ def try_initialize_namespace(namespace: str) -> bool: return False -def _get_global_lock() -> LockType: - return _global_lock - - def get_storage_lock() -> LockType: """return storage lock for data consistency""" - return _get_global_lock() - - -def get_scan_lock() -> LockType: - """return scan_progress lock for data consistency""" - return get_storage_lock() - - -def get_namespace_object(namespace: str) -> Any: - """Get an object for specific namespace""" - - if _share_objects is None: - direct_log( - f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", - level="ERROR", - ) - raise ValueError("Shared dictionaries not initialized") - - lock = _get_global_lock() - with lock: - if namespace not in _share_objects: - if namespace not in _share_objects: - if is_multiprocess: - _share_objects[namespace] = _manager.Value("O", None) - else: - _share_objects[namespace] = None - direct_log( - f"Created namespace: {namespace}(type={type(_share_objects[namespace])})" - ) - - return _share_objects[namespace] + return _global_lock def get_namespace_data(namespace: str) -> Dict[str, Any]: @@ -161,7 +117,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: ) raise ValueError("Shared dictionaries not initialized") - lock = _get_global_lock() + lock = get_storage_lock() with lock: if namespace not in _shared_dicts: if is_multiprocess and _manager is not None: @@ -175,11 +131,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: return _shared_dicts[namespace] -def get_scan_progress() -> Dict[str, Any]: - """get storage space for document scanning progress data""" - return get_namespace_data("scan_progress") - - def finalize_share_data(): """ Release shared resources and clean up. 
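The same helpers support a run-once initialization pattern across workers: the first process to claim a namespace seeds it, later processes simply attach to the shared dictionary. A sketch under that assumption (the "demo_status" namespace and its fields are made up for illustration; only the helper functions come from this module):

    from lightrag.kg.shared_storage import (
        initialize_share_data,
        try_initialize_namespace,
        get_namespace_data,
        get_storage_lock,
    )

    initialize_share_data(workers=2)             # multi-process mode: Manager-backed dicts and lock

    if try_initialize_namespace("demo_status"):  # True only for the first caller of this namespace
        status = get_namespace_data("demo_status")
        with get_storage_lock():
            status.update({"busy": False, "latest_message": ""})
    else:
        status = get_namespace_data("demo_status")  # later callers reuse the already-seeded dict
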
@@ -195,7 +146,6 @@ def finalize_share_data(): is_multiprocess, \ _global_lock, \ _shared_dicts, \ - _share_objects, \ _init_flags, \ _initialized @@ -216,8 +166,6 @@ def finalize_share_data(): # Clear shared dictionaries first if _shared_dicts is not None: _shared_dicts.clear() - if _share_objects is not None: - _share_objects.clear() if _init_flags is not None: _init_flags.clear() @@ -234,7 +182,6 @@ def finalize_share_data(): _initialized = None is_multiprocess = None _shared_dicts = None - _share_objects = None _init_flags = None _global_lock = None From 3dcfa561d797f2637ab0677c94719816e32a4811 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 01:15:12 +0800 Subject: [PATCH 34/77] Remove debug logging --- lightrag/kg/shared_storage.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 681ef064..f7c2e909 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -62,15 +62,13 @@ def initialize_share_data(workers: int = 1): _manager = Manager() - # Force multi-process mode if workers > 1 if workers > 1: is_multiprocess = True _global_lock = _manager.Lock() - # Create shared dictionaries with manager _shared_dicts = _manager.dict() _init_flags = ( _manager.dict() - ) # Use shared dictionary to store initialization flags + ) direct_log( f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" ) @@ -124,9 +122,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: _shared_dicts[namespace] = _manager.dict() else: _shared_dicts[namespace] = {} - direct_log( - f"Created namespace: {{namespace}}({type(_shared_dicts[namespace])}) " - ) return _shared_dicts[namespace] From cd7648791a72af93efc04031c3fd7397550fe2ab Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 01:25:59 +0800 Subject: [PATCH 35/77] Fix linting --- lightrag/kg/faiss_impl.py | 11 +++-------- lightrag/kg/json_doc_status_impl.py | 4 +++- lightrag/kg/json_kv_impl.py | 4 +++- lightrag/kg/nano_vector_db_impl.py | 5 ++--- lightrag/kg/networkx_impl.py | 7 ++++--- lightrag/kg/shared_storage.py | 13 ++++++++----- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index a3520653..d0ef6ed0 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -17,6 +17,7 @@ if not pm.is_installed("faiss"): import faiss # type: ignore from threading import Lock as ThreadLock + @final @dataclass class FaissVectorDBStorage(BaseVectorStorage): @@ -59,7 +60,6 @@ class FaissVectorDBStorage(BaseVectorStorage): with self._storage_lock: self._load_faiss_index() - def _get_index(self): """Check if the shtorage should be reloaded""" return self._index @@ -224,10 +224,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Searching relations for entity {entity_name}") relations = [] for fid, meta in self._id_to_meta.items(): - if ( - meta.get("src_id") == entity_name - or meta.get("tgt_id") == entity_name - ): + if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: relations.append(fid) logger.debug(f"Found {len(relations)} relations for {entity_name}") @@ -265,7 +262,7 @@ class FaissVectorDBStorage(BaseVectorStorage): new_id_to_meta[new_fid] = vec_meta with self._storage_lock: - # Re-init index + # Re-init index self._index = faiss.IndexFlatIP(self._dim) if vectors_to_keep: arr = np.array(vectors_to_keep, dtype=np.float32) @@ -273,7 +270,6 @@ class 
FaissVectorDBStorage(BaseVectorStorage): self._id_to_meta = new_id_to_meta - def _save_faiss_index(self): """ Save the current Faiss index + metadata to disk so it can persist across runs. @@ -290,7 +286,6 @@ class FaissVectorDBStorage(BaseVectorStorage): with open(self._meta_file, "w", encoding="utf-8") as f: json.dump(serializable_dict, f) - def _load_faiss_index(self): """ Load the Faiss index + metadata from disk if it exists, diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index b71cf618..05e6da37 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -84,7 +84,9 @@ class JsonDocStatusStorage(DocStatusStorage): async def index_done_callback(self) -> None: with self._storage_lock: - data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) write_json(data_dict, self._file_name) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index c5bff177..a4ce91a5 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -36,7 +36,9 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self) -> None: with self._storage_lock: - data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) write_json(data_dict, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index b8fe573d..bbf991bf 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -18,6 +18,7 @@ if not pm.is_installed("nano-vectordb"): from nano_vectordb import NanoVectorDB from threading import Lock as ThreadLock + @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): @@ -148,9 +149,7 @@ class NanoVectorDBStorage(BaseVectorStorage): for dp in storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name ] - logger.debug( - f"Found {len(relations)} relations for entity {entity_name}" - ) + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 1f14d5b0..ccf85855 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -19,6 +19,7 @@ import networkx as nx from graspologic import embed from threading import Lock as ThreadLock + @final @dataclass class NetworkXStorage(BaseGraphStorage): @@ -231,9 +232,9 @@ class NetworkXStorage(BaseGraphStorage): if len(subgraph.nodes()) > max_graph_nodes: origin_nodes = len(subgraph.nodes()) node_degrees = dict(subgraph.degree()) - top_nodes = sorted( - node_degrees.items(), key=lambda x: x[1], reverse=True - )[:max_graph_nodes] + top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ + :max_graph_nodes + ] top_node_ids = [node[0] for node in top_nodes] # Create new subgraph with only top nodes subgraph = subgraph.subgraph(top_node_ids) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index f7c2e909..19b1b1cb 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -26,6 +26,7 @@ _global_lock: Optional[LockType] = None _shared_dicts: Optional[Dict[str, 
Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized + def initialize_share_data(workers: int = 1): """ Initialize shared storage data for single or multi-process mode. @@ -66,9 +67,7 @@ def initialize_share_data(workers: int = 1): is_multiprocess = True _global_lock = _manager.Lock() _shared_dicts = _manager.dict() - _init_flags = ( - _manager.dict() - ) + _init_flags = _manager.dict() direct_log( f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" ) @@ -95,9 +94,13 @@ def try_initialize_namespace(namespace: str) -> bool: if namespace not in _init_flags: _init_flags[namespace] = True - direct_log(f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]") + direct_log( + f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]" + ) return True - direct_log(f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]") + direct_log( + f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]" + ) return False From b4bcd765991cc29cab35ef13900aabdbff194fb8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 10:53:36 +0800 Subject: [PATCH 36/77] Remove useless scan progress tracking functionality and related code --- lightrag/api/routers/document_routes.py | 75 +------------------------ lightrag/lightrag.py | 13 ----- 2 files changed, 1 insertion(+), 87 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 2a6459fb..3bb36830 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -4,7 +4,6 @@ This module contains all document-related routes for the LightRAG API. import asyncio import logging -import os import aiofiles import shutil import traceback @@ -12,17 +11,12 @@ import pipmaster as pm from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Any -from ascii_colors import ASCIIColors from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, UploadFile from pydantic import BaseModel, Field, field_validator from lightrag import LightRAG from lightrag.base import DocProcessingStatus, DocStatus from ..utils_api import get_api_key_dependency -from lightrag.kg.shared_storage import ( - get_namespace_data, - get_storage_lock, -) router = APIRouter(prefix="/documents", tags=["documents"]) @@ -376,72 +370,19 @@ async def save_temp_file(input_dir: Path, file: UploadFile = File(...)) -> Path: async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): """Background task to scan and index documents""" - scan_progress = get_namespace_data("scan_progress") - scan_lock = get_storage_lock() - with scan_lock: - if scan_progress.get("is_scanning", False): - ASCIIColors.info("Skip document scanning(another scanning is active)") - return - scan_progress.update( - { - "is_scanning": True, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - } - ) - try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - scan_progress.update( - { - "current_file": "", - "total_files": total_files, - "indexed_count": 0, - "progress": 0, - } - ) - logging.info(f"Found {total_files} new files to index.") + for idx, file_path in enumerate(new_files): try: - progress = (idx / total_files * 100) if total_files > 0 else 0 - scan_progress.update( - { - "current_file": os.path.basename(file_path), - "indexed_count": idx, - "progress": 
progress, - } - ) - await pipeline_index_file(rag, file_path) - - progress = ((idx + 1) / total_files * 100) if total_files > 0 else 0 - scan_progress.update( - { - "current_file": os.path.basename(file_path), - "indexed_count": idx + 1, - "progress": progress, - } - ) - except Exception as e: logging.error(f"Error indexing file {file_path}: {str(e)}") except Exception as e: logging.error(f"Error during scanning process: {str(e)}") - finally: - scan_progress.update( - { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - } - ) def create_document_routes( @@ -465,20 +406,6 @@ def create_document_routes( background_tasks.add_task(run_scanning_process, rag, doc_manager) return {"status": "scanning_started"} - @router.get("/scan-progress") - async def get_scanning_progress(): - """ - Get the current progress of the document scanning process. - - Returns: - dict: A dictionary containing the current scanning progress information including: - - is_scanning: Whether a scan is currently in progress - - current_file: The file currently being processed - - indexed_count: Number of files indexed so far - - total_files: Total number of files to process - - progress: Percentage of completion - """ - return dict(get_namespace_data("scan_progress")) @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0011fb6f..72f31315 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -276,20 +276,7 @@ class LightRAG: try_initialize_namespace, get_namespace_data, ) - initialize_share_data() - need_init = try_initialize_namespace("scan_progress") - scan_progress = get_namespace_data("scan_progress") - if need_init: - scan_progress.update( - { - "is_scanning": False, - "current_file": "", - "indexed_count": 0, - "total_files": 0, - "progress": 0, - } - ) if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") From feaa7ce69d2c2eac02e5dc91f5e62cd02a127ad4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 10:54:43 +0800 Subject: [PATCH 37/77] Remove auto-scaling of workers based on CPU count in gunicorn config --- gunicorn_config.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index e89b8e12..9cdb18e8 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -1,6 +1,5 @@ # gunicorn_config.py import os -import multiprocessing from lightrag.kg.shared_storage import finalize_share_data from lightrag.api.utils_api import parse_args @@ -10,10 +9,6 @@ args = parse_args() # Determine worker count - from environment variable or command line arguments workers = int(os.getenv("WORKERS", args.workers)) -# If not specified, use CPU count * 2 + 1 (Gunicorn recommended configuration) -if workers <= 1: - workers = multiprocessing.cpu_count() * 2 + 1 - # Binding address bind = f"{os.getenv('HOST', args.host)}:{os.getenv('PORT', args.port)}" @@ -44,7 +39,7 @@ def on_starting(server): You can use this function to do more initialization tasks for all processes """ print("=" * 80) - print(f"GUNICORN MASTER PROCESS: on_starting jobs for all {workers} workers") + print(f"GUNICORN MASTER PROCESS: on_starting jobs for {workers} worker(s)") print(f"Process ID: {os.getpid()}") print("=" * 80) From b2da69b7f172d9f633095de8780de80071573701 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 11:52:42 +0800 Subject: [PATCH 38/77] Add pipeline 
status control for concurrent document indexing processes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add shared pipeline status namespace โ€ข Implement concurrent process control โ€ข Add request queuing for pending jobs --- lightrag/kg/shared_storage.py | 12 ++ lightrag/lightrag.py | 268 +++++++++++++++++++++------------- 2 files changed, 176 insertions(+), 104 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 19b1b1cb..9369376e 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -80,6 +80,18 @@ def initialize_share_data(workers: int = 1): # Mark as initialized _initialized = True + + # Initialize pipeline status for document indexing control + pipeline_namespace = get_namespace_data("pipeline_status") + pipeline_namespace.update({ + "busy": False, # Control concurrent processes + "job_name": "Default Job", # Current job name (indexing files/indexing texts) + "job_start": None, # Job start time + "docs": 0, # Total number of documents to be indexed + "batchs": 0, # Number of batches for processing documents + "cur_batch": 0, # Current processing batch + "request_pending": False, # Flag for pending request for processing + }) def try_initialize_namespace(namespace: str) -> bool: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 72f31315..b95da952 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -273,8 +273,6 @@ class LightRAG: from lightrag.kg.shared_storage import ( initialize_share_data, - try_initialize_namespace, - get_namespace_data, ) initialize_share_data() @@ -672,117 +670,179 @@ class LightRAG: 3. Process each chunk for entity and relation extraction 4. Update the document status """ - # 1. Get all pending, failed, and abnormally terminated processing documents. - # Run the asynchronous status retrievals in parallel using asyncio.gather - processing_docs, failed_docs, pending_docs = await asyncio.gather( - self.doc_status.get_docs_by_status(DocStatus.PROCESSING), - self.doc_status.get_docs_by_status(DocStatus.FAILED), - self.doc_status.get_docs_by_status(DocStatus.PENDING), - ) - - to_process_docs: dict[str, DocProcessingStatus] = {} - to_process_docs.update(processing_docs) - to_process_docs.update(failed_docs) - to_process_docs.update(pending_docs) - - if not to_process_docs: - logger.info("All documents have been processed or are duplicates") + from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock + + # Get pipeline status shared data and lock + pipeline_status = get_namespace_data("pipeline_status") + storage_lock = get_storage_lock() + + # Check if another process is already processing the queue + process_documents = False + with storage_lock: + if not pipeline_status.get("busy", False): + # No other process is busy, we can process documents + pipeline_status.update({ + "busy": True, + "job_name": "indexing files", + "job_start": datetime.now().isoformat(), + "docs": 0, + "batchs": 0, + "cur_batch": 0, + "request_pending": False # Clear any previous request + }) + process_documents = True + else: + # Another process is busy, just set request flag and return + pipeline_status["request_pending"] = True + logger.info("Another process is already processing the document queue. Request queued.") + + if not process_documents: return + + try: + # Process documents until no more documents or requests + while True: + # 1. Get all pending, failed, and abnormally terminated processing documents. 
+ processing_docs, failed_docs, pending_docs = await asyncio.gather( + self.doc_status.get_docs_by_status(DocStatus.PROCESSING), + self.doc_status.get_docs_by_status(DocStatus.FAILED), + self.doc_status.get_docs_by_status(DocStatus.PENDING), + ) - # 2. split docs into chunks, insert chunks, update doc status - docs_batches = [ - list(to_process_docs.items())[i : i + self.max_parallel_insert] - for i in range(0, len(to_process_docs), self.max_parallel_insert) - ] + to_process_docs: dict[str, DocProcessingStatus] = {} + to_process_docs.update(processing_docs) + to_process_docs.update(failed_docs) + to_process_docs.update(pending_docs) - logger.info(f"Number of batches to process: {len(docs_batches)}.") + if not to_process_docs: + logger.info("All documents have been processed or are duplicates") + break - batches: list[Any] = [] - # 3. iterate over batches - for batch_idx, docs_batch in enumerate(docs_batches): + # Update pipeline status with document count (with lock) + with storage_lock: + pipeline_status["docs"] = len(to_process_docs) + + # 2. split docs into chunks, insert chunks, update doc status + docs_batches = [ + list(to_process_docs.items())[i : i + self.max_parallel_insert] + for i in range(0, len(to_process_docs), self.max_parallel_insert) + ] - async def batch( - batch_idx: int, - docs_batch: list[tuple[str, DocProcessingStatus]], - size_batch: int, - ) -> None: - logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.") - # 4. iterate over batch - for doc_id_processing_status in docs_batch: - doc_id, status_doc = doc_id_processing_status - # Generate chunks from document - chunks: dict[str, Any] = { - compute_mdhash_id(dp["content"], prefix="chunk-"): { - **dp, - "full_doc_id": doc_id, - } - for dp in self.chunking_func( - status_doc.content, - split_by_character, - split_by_character_only, - self.chunk_overlap_token_size, - self.chunk_token_size, - self.tiktoken_model_name, - ) - } - # Process document (text chunks and full docs) in parallel - tasks = [ - self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.PROCESSING, - "updated_at": datetime.now().isoformat(), - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, + # Update pipeline status with batch information (directly, as it's atomic) + pipeline_status.update({ + "batchs": len(docs_batches), + "cur_batch": 0 + }) + + logger.info(f"Number of batches to process: {len(docs_batches)}.") + + batches: list[Any] = [] + # 3. iterate over batches + for batch_idx, docs_batch in enumerate(docs_batches): + # Update current batch in pipeline status (directly, as it's atomic) + pipeline_status["cur_batch"] = batch_idx + 1 + + async def batch( + batch_idx: int, + docs_batch: list[tuple[str, DocProcessingStatus]], + size_batch: int, + ) -> None: + logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.") + # 4. 
iterate over batch + for doc_id_processing_status in docs_batch: + doc_id, status_doc = doc_id_processing_status + # Generate chunks from document + chunks: dict[str, Any] = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_id, } + for dp in self.chunking_func( + status_doc.content, + split_by_character, + split_by_character_only, + self.chunk_overlap_token_size, + self.chunk_token_size, + self.tiktoken_model_name, + ) } - ), - self.chunks_vdb.upsert(chunks), - self._process_entity_relation_graph(chunks), - self.full_docs.upsert( - {doc_id: {"content": status_doc.content}} - ), - self.text_chunks.upsert(chunks), - ] - try: - await asyncio.gather(*tasks) - await self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.PROCESSED, - "chunks_count": len(chunks), - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, - "updated_at": datetime.now().isoformat(), - } - } - ) - except Exception as e: - logger.error(f"Failed to process document {doc_id}: {str(e)}") - await self.doc_status.upsert( - { - doc_id: { - "status": DocStatus.FAILED, - "error": str(e), - "content": status_doc.content, - "content_summary": status_doc.content_summary, - "content_length": status_doc.content_length, - "created_at": status_doc.created_at, - "updated_at": datetime.now().isoformat(), - } - } - ) - continue - logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.") + # Process document (text chunks and full docs) in parallel + tasks = [ + self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.PROCESSING, + "updated_at": datetime.now().isoformat(), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + } + } + ), + self.chunks_vdb.upsert(chunks), + self._process_entity_relation_graph(chunks), + self.full_docs.upsert( + {doc_id: {"content": status_doc.content}} + ), + self.text_chunks.upsert(chunks), + ] + try: + await asyncio.gather(*tasks) + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.PROCESSED, + "chunks_count": len(chunks), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + } + } + ) + except Exception as e: + logger.error(f"Failed to process document {doc_id}: {str(e)}") + await self.doc_status.upsert( + { + doc_id: { + "status": DocStatus.FAILED, + "error": str(e), + "content": status_doc.content, + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + "updated_at": datetime.now().isoformat(), + } + } + ) + continue + logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.") - batches.append(batch(batch_idx, docs_batch, len(docs_batches))) + batches.append(batch(batch_idx, docs_batch, len(docs_batches))) - await asyncio.gather(*batches) - await self._insert_done() + await asyncio.gather(*batches) + await self._insert_done() + + # Check if there's a pending request to process more documents (with lock) + has_pending_request = False + with storage_lock: + has_pending_request = pipeline_status.get("request_pending", False) + if has_pending_request: + # Clear the request flag before checking for more documents + pipeline_status["request_pending"] = False 
+ + if not has_pending_request: + break + + logger.info("Processing additional documents due to pending request") + + finally: + # Always reset busy status when done or if an exception occurs (with lock) + with storage_lock: + pipeline_status["busy"] = False + logger.info("Document processing pipeline completed") async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: try: From 04bd5413c991583c0c530834890e201cee8b0ba9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 12:21:50 +0800 Subject: [PATCH 39/77] Add API endpoint to retrieve document indexing pipeline status MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข GET /pipeline_status endpoint added โ€ข Returns current pipeline processing state --- lightrag/api/routers/document_routes.py | 29 +++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 3bb36830..d2ff91f1 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -653,6 +653,35 @@ def create_document_routes( logging.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + @router.get("/pipeline_status", dependencies=[Depends(optional_api_key)]) + async def get_pipeline_status(): + """ + Get the current status of the document indexing pipeline. + + This endpoint returns information about the current state of the document processing pipeline, + including whether it's busy, the current job name, when it started, how many documents + are being processed, how many batches there are, and which batch is currently being processed. + + Returns: + dict: A dictionary containing the pipeline status information + """ + try: + from lightrag.kg.shared_storage import get_namespace_data + pipeline_status = get_namespace_data("pipeline_status") + + # Convert to regular dict if it's a Manager.dict + status_dict = dict(pipeline_status) + + # Format the job_start time if it exists + if status_dict.get("job_start"): + status_dict["job_start"] = str(status_dict["job_start"]) + + return status_dict + except Exception as e: + logging.error(f"Error getting pipeline status: {str(e)}") + logging.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + @router.get("", dependencies=[Depends(optional_api_key)]) async def documents() -> DocsStatusesResponse: """ From b090a22be7553b947f182138e4fe4379dc915507 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 12:22:20 +0800 Subject: [PATCH 40/77] Add concurrency check for auto scan task to prevent duplicate scans MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add pipeline status check before scan โ€ข Add storage lock protection โ€ข Add latest_message to status tracking โ€ข Add helpful log message at startup --- lightrag/api/lightrag_server.py | 23 +++++++++++++++++++---- lightrag/kg/shared_storage.py | 1 + 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 155e22f5..9af1a90e 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -145,10 +145,25 @@ def create_app(args): # Auto scan documents if enabled if args.auto_scan_at_startup: - # Create background task - task = asyncio.create_task(run_scanning_process(rag, doc_manager)) - app.state.background_tasks.add(task) - 
task.add_done_callback(app.state.background_tasks.discard) + # Import necessary functions from shared_storage + from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock + + # Get pipeline status and lock + pipeline_status = get_namespace_data("pipeline_status") + storage_lock = get_storage_lock() + + # Check if a task is already running (with lock protection) + should_start_task = False + with storage_lock: + if not pipeline_status.get("busy", False): + should_start_task = True + # Only start the task if no other task is running + if should_start_task: + # Create background task + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) + app.state.background_tasks.add(task) + task.add_done_callback(app.state.background_tasks.discard) + logger.info("Auto scan task started at startup.") ASCIIColors.green("\nServer is ready to accept connections! ๐Ÿš€\n") diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 9369376e..a4970321 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -91,6 +91,7 @@ def initialize_share_data(workers: int = 1): "batchs": 0, # Number of batches for processing documents "cur_batch": 0, # Current processing batch "request_pending": False, # Flag for pending request for processing + "latest_message": "" # Latest message from pipeline processing }) From 8cd45161f2df57a447266bf1f839b948c20c0388 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 13:53:40 +0800 Subject: [PATCH 41/77] feat: add history_messages to track pipeline processing progress MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add shared history_messages list โ€ข Track pipeline progress with messages --- lightrag/api/routers/document_routes.py | 4 ++ lightrag/kg/shared_storage.py | 7 +++- lightrag/lightrag.py | 52 +++++++++++++++++++++---- lightrag/operate.py | 32 ++++++++++----- 4 files changed, 77 insertions(+), 18 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index d2ff91f1..e274f4c4 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -672,6 +672,10 @@ def create_document_routes( # Convert to regular dict if it's a Manager.dict status_dict = dict(pipeline_status) + # Convert history_messages to a regular list if it's a Manager.list + if "history_messages" in status_dict: + status_dict["history_messages"] = list(status_dict["history_messages"]) + # Format the job_start time if it exists if status_dict.get("job_start"): status_dict["job_start"] = str(status_dict["job_start"]) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index a4970321..3a21dc5c 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -83,6 +83,10 @@ def initialize_share_data(workers: int = 1): # Initialize pipeline status for document indexing control pipeline_namespace = get_namespace_data("pipeline_status") + + # ๅˆ›ๅปบไธ€ไธชๅ…ฑไบซๅˆ—่กจๅฏน่ฑก็”จไบŽ history_messages + history_messages = _manager.list() if is_multiprocess else [] + pipeline_namespace.update({ "busy": False, # Control concurrent processes "job_name": "Default Job", # Current job name (indexing files/indexing texts) @@ -91,7 +95,8 @@ def initialize_share_data(workers: int = 1): "batchs": 0, # Number of batches for processing documents "cur_batch": 0, # Current processing batch "request_pending": False, # Flag for pending request for processing - "latest_message": "" # 
Latest message from pipeline processing + "latest_message": "", # Latest message from pipeline processing + "history_messages": history_messages, # ไฝฟ็”จๅ…ฑไบซๅˆ—่กจๅฏน่ฑก }) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b95da952..ee5bc397 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -681,6 +681,13 @@ class LightRAG: with storage_lock: if not pipeline_status.get("busy", False): # No other process is busy, we can process documents + # ่Žทๅ–ๅฝ“ๅ‰็š„ history_messages ๅˆ—่กจ + current_history = pipeline_status.get("history_messages", []) + + # ๆธ…็ฉบๅฝ“ๅ‰ๅˆ—่กจๅ†…ๅฎนไฝ†ไฟๆŒๅŒไธ€ไธชๅˆ—่กจๅฏน่ฑก + if hasattr(current_history, "clear"): + current_history.clear() + pipeline_status.update({ "busy": True, "job_name": "indexing files", @@ -688,7 +695,10 @@ class LightRAG: "docs": 0, "batchs": 0, "cur_batch": 0, - "request_pending": False # Clear any previous request + "request_pending": False, # Clear any previous request + "latest_message": "", + # ไฟๆŒไฝฟ็”จๅŒไธ€ไธชๅˆ—่กจๅฏน่ฑก + "history_messages": current_history, }) process_documents = True else: @@ -715,7 +725,10 @@ class LightRAG: to_process_docs.update(pending_docs) if not to_process_docs: - logger.info("All documents have been processed or are duplicates") + log_message = "All documents have been processed or are duplicates" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) break # Update pipeline status with document count (with lock) @@ -734,7 +747,10 @@ class LightRAG: "cur_batch": 0 }) - logger.info(f"Number of batches to process: {len(docs_batches)}.") + log_message = f"Number of batches to process: {len(docs_batches)}." + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) batches: list[Any] = [] # 3. iterate over batches @@ -747,7 +763,10 @@ class LightRAG: docs_batch: list[tuple[str, DocProcessingStatus]], size_batch: int, ) -> None: - logger.info(f"Start processing batch {batch_idx + 1} of {size_batch}.") + log_message = f"Start processing batch {batch_idx + 1} of {size_batch}." + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) # 4. iterate over batch for doc_id_processing_status in docs_batch: doc_id, status_doc = doc_id_processing_status @@ -818,7 +837,10 @@ class LightRAG: } ) continue - logger.info(f"Completed batch {batch_idx + 1} of {len(docs_batches)}.") + log_message = f"Completed batch {batch_idx + 1} of {len(docs_batches)}." 
+ logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) batches.append(batch(batch_idx, docs_batch, len(docs_batches))) @@ -836,13 +858,19 @@ class LightRAG: if not has_pending_request: break - logger.info("Processing additional documents due to pending request") + log_message = "Processing additional documents due to pending request" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) finally: # Always reset busy status when done or if an exception occurs (with lock) with storage_lock: pipeline_status["busy"] = False - logger.info("Document processing pipeline completed") + log_message = "Document processing pipeline completed" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: try: @@ -873,7 +901,15 @@ class LightRAG: if storage_inst is not None ] await asyncio.gather(*tasks) - logger.info("All Insert done") + + log_message = "All Insert done" + logger.info(log_message) + + # ่Žทๅ– pipeline_status ๅนถๆ›ดๆ–ฐ latest_message ๅ’Œ history_messages + from lightrag.kg.shared_storage import get_namespace_data + pipeline_status = get_namespace_data("pipeline_status") + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) def insert_custom_kg(self, custom_kg: dict[str, Any]) -> None: loop = always_get_an_event_loop() diff --git a/lightrag/operate.py b/lightrag/operate.py index e3f445bb..44a68655 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -336,6 +336,9 @@ async def extract_entities( global_config: dict[str, str], llm_response_cache: BaseKVStorage | None = None, ) -> None: + # ๅœจๅ‡ฝๆ•ฐๅผ€ๅง‹ๅค„ๆทปๅŠ ่Žทๅ– pipeline_status ็š„ไปฃ็  + from lightrag.kg.shared_storage import get_namespace_data + pipeline_status = get_namespace_data("pipeline_status") use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[ @@ -496,9 +499,10 @@ async def extract_entities( processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) - logger.info( - f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)" - ) + log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) return dict(maybe_nodes), dict(maybe_edges) tasks = [_process_single_content(c) for c in ordered_chunks] @@ -527,17 +531,27 @@ async def extract_entities( ) if not (all_entities_data or all_relationships_data): - logger.info("Didn't extract any entities and relationships.") + log_message = "Didn't extract any entities and relationships." 
+ logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) return if not all_entities_data: - logger.info("Didn't extract any entities") + log_message = "Didn't extract any entities" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) if not all_relationships_data: - logger.info("Didn't extract any relationships") + log_message = "Didn't extract any relationships" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - logger.info( - f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)" - ) + log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) verbose_debug( f"New entities:{all_entities_data}, relationships:{all_relationships_data}" ) From 157ec862aec115e374775523acaf05b35071dda3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 14:57:25 +0800 Subject: [PATCH 42/77] Enhance logging system with file rotation and unified configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Unify logging across Gunicorn and Uvicorn โ€ข Add rotating file handlers --- gunicorn_config.py | 60 +++++++++++++++++++++++- lightrag/api/lightrag_server.py | 37 +++++++++++++-- lightrag/api/routers/document_routes.py | 62 ++++++++++++------------- lightrag/utils.py | 47 ++++++++++++++++--- run_with_gunicorn.py | 4 ++ 5 files changed, 166 insertions(+), 44 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 9cdb18e8..fdb0140a 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -1,5 +1,8 @@ # gunicorn_config.py import os +import logging +from logging.config import dictConfig +from logging.handlers import RotatingFileHandler from lightrag.kg.shared_storage import finalize_share_data from lightrag.api.utils_api import parse_args @@ -27,11 +30,64 @@ if args.ssl: certfile = args.ssl_certfile keyfile = args.ssl_keyfile +# ่Žทๅ–ๆ—ฅๅฟ—ๆ–‡ไปถ่ทฏๅพ„ +log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) + # Logging configuration -errorlog = os.getenv("ERROR_LOG", "-") # '-' means stderr -accesslog = os.getenv("ACCESS_LOG", "-") # '-' means stderr +errorlog = os.getenv("ERROR_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log +accesslog = os.getenv("ACCESS_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log loglevel = os.getenv("LOG_LEVEL", "info") +# ้…็ฝฎๆ—ฅๅฟ—็ณป็ปŸ +logconfig_dict = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'standard': { + 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s' + }, + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'INFO', + 'formatter': 'standard', + 'stream': 'ext://sys.stdout' + }, + 'file': { + 'class': 'logging.handlers.RotatingFileHandler', + 'level': 'INFO', + 'formatter': 'standard', + 'filename': log_file_path, + 'maxBytes': 10485760, # 10MB + 'backupCount': 5, + 'encoding': 'utf8' + } + }, + 'loggers': { + 'lightrag': { + 'handlers': ['console', 'file'], + 'level': 'INFO', + 'propagate': False + }, + 'uvicorn': { + 'handlers': ['console', 'file'], + 'level': 'INFO', + 'propagate': False + }, + 'gunicorn': 
{ + 'handlers': ['console', 'file'], + 'level': 'INFO', + 'propagate': False + }, + 'gunicorn.error': { + 'handlers': ['console', 'file'], + 'level': 'INFO', + 'propagate': False + } + } +} + def on_starting(server): """ diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 9af1a90e..66fcacde 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -438,13 +438,20 @@ def get_application(): def configure_logging(): """Configure logging for both uvicorn and lightrag""" + # Check if running under Gunicorn + if "GUNICORN_CMD_ARGS" in os.environ: + # If started with Gunicorn, return directly as Gunicorn will handle logging + return + # Reset any existing handlers to ensure clean configuration - for logger_name in ["uvicorn.access", "lightrag"]: + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: logger = logging.getLogger(logger_name) logger.handlers = [] logger.filters = [] # Configure basic logging + log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) + logging.config.dictConfig( { "version": 1, @@ -453,23 +460,45 @@ def configure_logging(): "default": { "format": "%(levelname)s: %(message)s", }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, }, "handlers": { - "default": { + "console": { "formatter": "default", "class": "logging.StreamHandler", "stream": "ext://sys.stderr", }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": 10*1024*1024, # 10MB + "backupCount": 5, + "encoding": "utf-8", + }, }, "loggers": { + # Configure all uvicorn related loggers + "uvicorn": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, "uvicorn.access": { - "handlers": ["default"], + "handlers": ["console", "file"], "level": "INFO", "propagate": False, "filters": ["path_filter"], }, + "uvicorn.error": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, "lightrag": { - "handlers": ["default"], + "handlers": ["console", "file"], "level": "INFO", "propagate": False, "filters": ["path_filter"], diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index e274f4c4..3126b8ce 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -3,7 +3,7 @@ This module contains all document-related routes for the LightRAG API. 
""" import asyncio -import logging +from lightrag.utils import logger import aiofiles import shutil import traceback @@ -147,7 +147,7 @@ class DocumentManager: """Scan input directory for new files""" new_files = [] for ext in self.supported_extensions: - logging.debug(f"Scanning for {ext} files in {self.input_dir}") + logger.debug(f"Scanning for {ext} files in {self.input_dir}") for file_path in self.input_dir.rglob(f"*{ext}"): if file_path not in self.indexed_files: new_files.append(file_path) @@ -266,7 +266,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: ) content += "\n" case _: - logging.error( + logger.error( f"Unsupported file type: {file_path.name} (extension {ext})" ) return False @@ -274,20 +274,20 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: # Insert into the RAG queue if content: await rag.apipeline_enqueue_documents(content) - logging.info(f"Successfully fetched and enqueued file: {file_path.name}") + logger.info(f"Successfully fetched and enqueued file: {file_path.name}") return True else: - logging.error(f"No content could be extracted from file: {file_path.name}") + logger.error(f"No content could be extracted from file: {file_path.name}") except Exception as e: - logging.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error processing or enqueueing file {file_path.name}: {str(e)}") + logger.error(traceback.format_exc()) finally: if file_path.name.startswith(temp_prefix): try: file_path.unlink() except Exception as e: - logging.error(f"Error deleting file {file_path}: {str(e)}") + logger.error(f"Error deleting file {file_path}: {str(e)}") return False @@ -303,8 +303,8 @@ async def pipeline_index_file(rag: LightRAG, file_path: Path): await rag.apipeline_process_enqueue_documents() except Exception as e: - logging.error(f"Error indexing file {file_path.name}: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error indexing file {file_path.name}: {str(e)}") + logger.error(traceback.format_exc()) async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): @@ -328,8 +328,8 @@ async def pipeline_index_files(rag: LightRAG, file_paths: List[Path]): if enqueued: await rag.apipeline_process_enqueue_documents() except Exception as e: - logging.error(f"Error indexing files: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error indexing files: {str(e)}") + logger.error(traceback.format_exc()) async def pipeline_index_texts(rag: LightRAG, texts: List[str]): @@ -373,16 +373,16 @@ async def run_scanning_process(rag: LightRAG, doc_manager: DocumentManager): try: new_files = doc_manager.scan_directory_for_new_files() total_files = len(new_files) - logging.info(f"Found {total_files} new files to index.") + logger.info(f"Found {total_files} new files to index.") for idx, file_path in enumerate(new_files): try: await pipeline_index_file(rag, file_path) except Exception as e: - logging.error(f"Error indexing file {file_path}: {str(e)}") + logger.error(f"Error indexing file {file_path}: {str(e)}") except Exception as e: - logging.error(f"Error during scanning process: {str(e)}") + logger.error(f"Error during scanning process: {str(e)}") def create_document_routes( @@ -447,8 +447,8 @@ def create_document_routes( message=f"File '{file.filename}' uploaded successfully. 
Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/upload: {file.filename}: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/upload: {file.filename}: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -480,8 +480,8 @@ def create_document_routes( message="Text successfully received. Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/text: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -515,8 +515,8 @@ def create_document_routes( message="Text successfully received. Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/text: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/text: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -558,8 +558,8 @@ def create_document_routes( message=f"File '{file.filename}' saved successfully. Processing will continue in background.", ) except Exception as e: - logging.error(f"Error /documents/file: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/file: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.post( @@ -621,8 +621,8 @@ def create_document_routes( return InsertResponse(status=status, message=status_message) except Exception as e: - logging.error(f"Error /documents/batch: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error /documents/batch: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.delete( @@ -649,8 +649,8 @@ def create_document_routes( status="success", message="All documents cleared successfully" ) except Exception as e: - logging.error(f"Error DELETE /documents: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error DELETE /documents: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.get("/pipeline_status", dependencies=[Depends(optional_api_key)]) @@ -682,8 +682,8 @@ def create_document_routes( return status_dict except Exception as e: - logging.error(f"Error getting pipeline status: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error getting pipeline status: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.get("", dependencies=[Depends(optional_api_key)]) @@ -739,8 +739,8 @@ def create_document_routes( ) return response except Exception as e: - logging.error(f"Error GET /documents: {str(e)}") - logging.error(traceback.format_exc()) + logger.error(f"Error GET /documents: {str(e)}") + logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) return router diff --git a/lightrag/utils.py b/lightrag/utils.py index a6265048..3ec96112 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -75,18 +75,51 @@ def set_logger(log_file: str, level: int = logging.DEBUG): log_file: Path to the log file level: Logging level (e.g. 
logging.DEBUG, logging.INFO) """ + # ่ฎพ็ฝฎๆ—ฅๅฟ—็บงๅˆซ logger.setLevel(level) - - file_handler = logging.FileHandler(log_file, encoding="utf-8") - file_handler.setLevel(level) - + + # ็กฎไฟไฝฟ็”จ็ปๅฏน่ทฏๅพ„ + log_file = os.path.abspath(log_file) + + # ๅˆ›ๅปบๆ ผๅผๅŒ–ๅ™จ formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) - file_handler.setFormatter(formatter) - - if not logger.handlers: + + # ๆฃ€ๆŸฅๆ˜ฏๅฆๅทฒ็ปๆœ‰ๆ–‡ไปถๅค„็†ๅ™จ + has_file_handler = False + has_console_handler = False + + # ๆฃ€ๆŸฅ็Žฐๆœ‰ๅค„็†ๅ™จ + for handler in logger.handlers: + if isinstance(handler, logging.FileHandler): + has_file_handler = True + elif isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler): + has_console_handler = True + + # ๅฆ‚ๆžœๆฒกๆœ‰ๆ–‡ไปถๅค„็†ๅ™จ๏ผŒๆทปๅŠ ไธ€ไธช + if not has_file_handler: + # ไฝฟ็”จ RotatingFileHandler ไปฃๆ›ฟ FileHandler + from logging.handlers import RotatingFileHandler + file_handler = RotatingFileHandler( + log_file, + maxBytes=10*1024*1024, # 10MB + backupCount=5, + encoding="utf-8" + ) + file_handler.setLevel(level) + file_handler.setFormatter(formatter) logger.addHandler(file_handler) + + # ๅฆ‚ๆžœๆฒกๆœ‰ๆŽงๅˆถๅฐๅค„็†ๅ™จ๏ผŒๆทปๅŠ ไธ€ไธช + if not has_console_handler: + console_handler = logging.StreamHandler() + console_handler.setLevel(level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # ่ฎพ็ฝฎๆ—ฅๅฟ—ไผ ๆ’ญไธบ False๏ผŒ้ฟๅ…้‡ๅค่พ“ๅ‡บ + logger.propagate = False class UnlimitedSemaphore: diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 705cb88f..7b98cb1c 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -157,6 +157,10 @@ def main(): value = getattr(self.config_module, key) if callable(value): self.cfg.set(key, value) + + # ็กฎไฟๆญฃ็กฎๅŠ ่ฝฝ logconfig_dict + if hasattr(self.config_module, 'logconfig_dict'): + self.cfg.set('logconfig_dict', getattr(self.config_module, 'logconfig_dict')) # Override with command line arguments if provided if gunicorn_args.workers: From 81f6f6e343245cb7b61a69b601bcc0da949aed2a Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 16:07:33 +0800 Subject: [PATCH 43/77] Fix lightrag logger initailization problem, fix gunicorn acccess log missing --- gunicorn_config.py | 10 +++++----- lightrag/api/lightrag_server.py | 11 ++--------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index fdb0140a..3139d21e 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -70,11 +70,6 @@ logconfig_dict = { 'level': 'INFO', 'propagate': False }, - 'uvicorn': { - 'handlers': ['console', 'file'], - 'level': 'INFO', - 'propagate': False - }, 'gunicorn': { 'handlers': ['console', 'file'], 'level': 'INFO', @@ -84,6 +79,11 @@ logconfig_dict = { 'handlers': ['console', 'file'], 'level': 'INFO', 'propagate': False + }, + 'gunicorn.access': { + 'handlers': ['console', 'file'], + 'level': 'INFO', + 'propagate': False } } } diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 66fcacde..33a03cbc 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -39,13 +39,10 @@ from .routers.query_routes import create_query_routes from .routers.graph_routes import create_graph_routes from .routers.ollama_api import OllamaAPI -from lightrag.utils import logger as utils_logger +from lightrag.utils import logger, set_verbose_debug # Load environment variables -try: - load_dotenv(override=True) -except Exception as e: - 
utils_logger.warning(f"Failed to load .env file: {e}") +load_dotenv(override=True) # Initialize config parser config = configparser.ConfigParser() @@ -88,10 +85,6 @@ class LightragPathFilter(logging.Filter): def create_app(args): - # Initialize verbose debug setting - # Can not use the logger at the top of this module when workers > 1 - from lightrag.utils import set_verbose_debug, logger - # Setup logging logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) From f588cdc5df778f6cd425fa3ed88f5c4b056337a4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 16:50:54 +0800 Subject: [PATCH 44/77] Optimize logging config & worker handling for different server modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Separate logging config for uvicorn/gunicorn โ€ข Force workers=1 in uvicorn mode โ€ข Add warning for worker count in uvicorn --- lightrag/api/lightrag_server.py | 15 ++++----------- lightrag/api/utils_api.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 33a03cbc..f130a0fa 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -414,9 +414,6 @@ def create_app(args): def get_application(): """Factory function for creating the FastAPI application""" - # Configure logging for this worker process - configure_logging() - # Get args from environment variable args_json = os.environ.get("LIGHTRAG_ARGS") if not args_json: @@ -430,11 +427,7 @@ def get_application(): def configure_logging(): - """Configure logging for both uvicorn and lightrag""" - # Check if running under Gunicorn - if "GUNICORN_CMD_ARGS" in os.environ: - # If started with Gunicorn, return directly as Gunicorn will handle logging - return + """Configure logging for uvicorn startup""" # Reset any existing handlers to ensure clean configuration for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: @@ -517,13 +510,13 @@ def main(): freeze_support() + # Configure logging before parsing args + configure_logging() + args = parse_args() # Save args to environment variable for child processes os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) - # Configure logging before starting uvicorn - configure_logging() - display_splash_screen(args) # Create application instance directly instead of using factory function diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index c494101c..f63e9c92 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -6,6 +6,7 @@ import os import argparse from typing import Optional import sys +import logging from ascii_colors import ASCIIColors from lightrag.api import __api_version__ from fastapi import HTTPException, Security @@ -286,6 +287,16 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() + # Check if running under uvicorn mode (not Gunicorn) + is_uvicorn_mode = "GUNICORN_CMD_ARGS" not in os.environ + + # If in uvicorn mode and workers > 1, force it to 1 and log warning + if is_uvicorn_mode and args.workers > 1: + original_workers = args.workers + args.workers = 1 + # Log warning directly here + logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. 
Forcing workers=1") + # convert relative path to absolute path args.working_dir = os.path.abspath(args.working_dir) args.input_dir = os.path.abspath(args.input_dir) From ff549a3a9c28f74259fd8605bf1b5ae0f78f3341 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 17:45:40 +0800 Subject: [PATCH 45/77] Update Gunicorn config with logging filters and worker-specific configurations --- gunicorn_config.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 3139d21e..daab1955 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -1,10 +1,9 @@ # gunicorn_config.py import os import logging -from logging.config import dictConfig -from logging.handlers import RotatingFileHandler from lightrag.kg.shared_storage import finalize_share_data from lightrag.api.utils_api import parse_args +from lightrag.api.lightrag_server import LightragPathFilter # Parse command line arguments args = parse_args() @@ -64,6 +63,11 @@ logconfig_dict = { 'encoding': 'utf8' } }, + 'filters': { + 'path_filter': { + '()': 'lightrag.api.lightrag_server.LightragPathFilter', + }, + }, 'loggers': { 'lightrag': { 'handlers': ['console', 'file'], @@ -83,7 +87,8 @@ logconfig_dict = { 'gunicorn.access': { 'handlers': ['console', 'file'], 'level': 'INFO', - 'propagate': False + 'propagate': False, + 'filters': ['path_filter'] } } } @@ -131,3 +136,20 @@ def on_exit(server): print("=" * 80) print("Gunicorn shutdown complete") print("=" * 80) + + +def post_fork(server, worker): + """ + Executed after a worker has been forked. + This is a good place to set up worker-specific configurations. + """ + # Disable uvicorn.error logger in worker processes + uvicorn_error_logger = logging.getLogger("uvicorn.error") + uvicorn_error_logger.setLevel(logging.CRITICAL) + uvicorn_error_logger.handlers = [] + uvicorn_error_logger.propagate = False + + # Add log filter to uvicorn.access handler in worker processes + uvicorn_access_logger = logging.getLogger("uvicorn.access") + path_filter = LightragPathFilter() + uvicorn_access_logger.addFilter(path_filter) From c37b1e8aa7c7404075cde2a1261d0c0c903ae50f Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 20:41:11 +0800 Subject: [PATCH 46/77] Align Gunicorn configuration with Uvicorn - centralize config in gunicorn_config.py - fix log level handling in Gunicorn --- gunicorn_config.py | 39 ++++++-------- lightrag/api/lightrag_server.py | 20 ++------ lightrag/api/utils_api.py | 8 +-- run_with_gunicorn.py | 90 ++++++++------------------------- 4 files changed, 47 insertions(+), 110 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index daab1955..0ca6f9d9 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -2,17 +2,17 @@ import os import logging from lightrag.kg.shared_storage import finalize_share_data -from lightrag.api.utils_api import parse_args from lightrag.api.lightrag_server import LightragPathFilter -# Parse command line arguments -args = parse_args() +# ่Žทๅ–ๆ—ฅๅฟ—ๆ–‡ไปถ่ทฏๅพ„ +log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) -# Determine worker count - from environment variable or command line arguments -workers = int(os.getenv("WORKERS", args.workers)) - -# Binding address -bind = f"{os.getenv('HOST', args.host)}:{os.getenv('PORT', args.port)}" +# These variables will be set by run_with_gunicorn.py +workers = None +bind = None +loglevel = None +certfile = None +keyfile = None # Enable preload_app option preload_app = True @@ -24,18 +24,9 
@@ worker_class = "uvicorn.workers.UvicornWorker" timeout = int(os.getenv("TIMEOUT", 120)) keepalive = 5 -# Optional SSL configuration -if args.ssl: - certfile = args.ssl_certfile - keyfile = args.ssl_keyfile - -# ่Žทๅ–ๆ—ฅๅฟ—ๆ–‡ไปถ่ทฏๅพ„ -log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) - # Logging configuration errorlog = os.getenv("ERROR_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log accesslog = os.getenv("ACCESS_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log -loglevel = os.getenv("LOG_LEVEL", "info") # ้…็ฝฎๆ—ฅๅฟ—็ณป็ปŸ logconfig_dict = { @@ -49,13 +40,11 @@ logconfig_dict = { 'handlers': { 'console': { 'class': 'logging.StreamHandler', - 'level': 'INFO', 'formatter': 'standard', 'stream': 'ext://sys.stdout' }, 'file': { 'class': 'logging.handlers.RotatingFileHandler', - 'level': 'INFO', 'formatter': 'standard', 'filename': log_file_path, 'maxBytes': 10485760, # 10MB @@ -71,22 +60,22 @@ logconfig_dict = { 'loggers': { 'lightrag': { 'handlers': ['console', 'file'], - 'level': 'INFO', + 'level': loglevel.upper() if loglevel else 'INFO', 'propagate': False }, 'gunicorn': { 'handlers': ['console', 'file'], - 'level': 'INFO', + 'level': loglevel.upper() if loglevel else 'INFO', 'propagate': False }, 'gunicorn.error': { 'handlers': ['console', 'file'], - 'level': 'INFO', + 'level': loglevel.upper() if loglevel else 'INFO', 'propagate': False }, 'gunicorn.access': { 'handlers': ['console', 'file'], - 'level': 'INFO', + 'level': loglevel.upper() if loglevel else 'INFO', 'propagate': False, 'filters': ['path_filter'] } @@ -143,6 +132,10 @@ def post_fork(server, worker): Executed after a worker has been forked. This is a good place to set up worker-specific configurations. """ + # Set lightrag logger level in worker processes using gunicorn's loglevel + from lightrag.utils import logger + logger.setLevel(loglevel.upper()) + # Disable uvicorn.error logger in worker processes uvicorn_error_logger = logging.getLogger("uvicorn.error") uvicorn_error_logger.setLevel(logging.CRITICAL) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f130a0fa..8f7a6781 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -86,7 +86,7 @@ class LightragPathFilter(logging.Filter): def create_app(args): # Setup logging - logger.setLevel(getattr(logging, args.log_level)) + logger.setLevel(args.log_level) set_verbose_debug(args.verbose) # Verify that bindings are correctly setup @@ -412,17 +412,10 @@ def create_app(args): return app -def get_application(): +def get_application(args=None): """Factory function for creating the FastAPI application""" - # Get args from environment variable - args_json = os.environ.get("LIGHTRAG_ARGS") - if not args_json: - args = parse_args() # Fallback to parsing args if env var not set - else: - import types - - args = types.SimpleNamespace(**json.loads(args_json)) - + if args is None: + args = parse_args() return create_app(args) @@ -513,10 +506,7 @@ def main(): # Configure logging before parsing args configure_logging() - args = parse_args() - # Save args to environment variable for child processes - os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) - + args = parse_args(is_uvicorn_mode=True) display_splash_screen(args) # Create application instance directly instead of using factory function diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index f63e9c92..4b5e0a28 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -111,10 +111,13 @@ def 
get_env_value(env_key: str, default: any, value_type: type = str) -> any: return default -def parse_args() -> argparse.Namespace: +def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: """ Parse command line arguments with environment variable fallback + Args: + is_uvicorn_mode: Whether running under uvicorn mode + Returns: argparse.Namespace: Parsed arguments """ @@ -287,9 +290,6 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() - # Check if running under uvicorn mode (not Gunicorn) - is_uvicorn_mode = "GUNICORN_CMD_ARGS" not in os.environ - # If in uvicorn mode and workers > 1, force it to 1 and log warning if is_uvicorn_mode and args.workers > 1: original_workers = args.workers diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 7b98cb1c..de2b21b6 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -30,45 +30,10 @@ def main(): # Register signal handlers for graceful shutdown signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGTERM, signal_handler) # kill command - # Create a parser to handle Gunicorn-specific parameters - parser = argparse.ArgumentParser(description="Start LightRAG server with Gunicorn") - parser.add_argument( - "--workers", - type=int, - help="Number of worker processes (overrides the default or config.ini setting)", - ) - parser.add_argument( - "--timeout", type=int, help="Worker timeout in seconds (default: 120)" - ) - parser.add_argument( - "--log-level", - choices=["debug", "info", "warning", "error", "critical"], - help="Gunicorn log level", - ) - # Parse Gunicorn-specific arguments - gunicorn_args, remaining_args = parser.parse_known_args() + # Parse all arguments using parse_args + args = parse_args(is_uvicorn_mode=False) - # Pass remaining arguments to LightRAG's parse_args - sys.argv = [sys.argv[0]] + remaining_args - args = parse_args() - - # If workers specified, override args value - if gunicorn_args.workers: - args.workers = gunicorn_args.workers - os.environ["WORKERS"] = str(gunicorn_args.workers) - - # If timeout specified, set environment variable - if gunicorn_args.timeout: - os.environ["TIMEOUT"] = str(gunicorn_args.timeout) - - # If log-level specified, set environment variable - if gunicorn_args.log_level: - os.environ["LOG_LEVEL"] = gunicorn_args.log_level - - # Save all LightRAG args to environment variable for worker processes - # This is the key step for passing arguments to lightrag_server.py - os.environ["LIGHTRAG_ARGS"] = json.dumps(vars(args)) # Display startup information display_splash_screen(args) @@ -83,11 +48,6 @@ def main(): print(f"Workers setting: {args.workers}") print("=" * 80 + "\n") - # Start application with Gunicorn using direct Python API - # Ensure WORKERS environment variable is set before importing gunicorn_config - if args.workers > 1: - os.environ["WORKERS"] = str(args.workers) - # Import Gunicorn's StandaloneApplication from gunicorn.app.base import BaseApplication @@ -136,51 +96,45 @@ def main(): "child_exit", } - # Import the gunicorn_config module directly - import importlib.util + # Import and configure the gunicorn_config module + import gunicorn_config - spec = importlib.util.spec_from_file_location( - "gunicorn_config", "gunicorn_config.py" - ) - self.config_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(self.config_module) + # Set configuration variables in gunicorn_config + gunicorn_config.workers = int(os.getenv("WORKERS", args.workers)) + gunicorn_config.bind = f"{os.getenv('HOST', 
args.host)}:{os.getenv('PORT', args.port)}" + gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info") + + # Set SSL configuration if enabled + if args.ssl: + gunicorn_config.certfile = args.ssl_certfile + gunicorn_config.keyfile = args.ssl_keyfile - # Set configuration options - for key in dir(self.config_module): + # Set configuration options from the module + for key in dir(gunicorn_config): if key in valid_options: - value = getattr(self.config_module, key) - # Skip functions like on_starting - if not callable(value): + value = getattr(gunicorn_config, key) + # Skip functions like on_starting and None values + if not callable(value) and value is not None: self.cfg.set(key, value) # Set special hooks elif key in special_hooks: - value = getattr(self.config_module, key) + value = getattr(gunicorn_config, key) if callable(value): self.cfg.set(key, value) # ็กฎไฟๆญฃ็กฎๅŠ ่ฝฝ logconfig_dict - if hasattr(self.config_module, 'logconfig_dict'): - self.cfg.set('logconfig_dict', getattr(self.config_module, 'logconfig_dict')) - - # Override with command line arguments if provided - if gunicorn_args.workers: - self.cfg.set("workers", gunicorn_args.workers) - if gunicorn_args.timeout: - self.cfg.set("timeout", gunicorn_args.timeout) - if gunicorn_args.log_level: - self.cfg.set("loglevel", gunicorn_args.log_level) + if hasattr(gunicorn_config, 'logconfig_dict'): + self.cfg.set('logconfig_dict', getattr(gunicorn_config, 'logconfig_dict')) def load(self): # Import the application from lightrag.api.lightrag_server import get_application - return get_application() + return get_application(args) # Create the application app = GunicornApp("") - # Directly call initialize_share_data with the correct workers value - # Force workers to be an integer and greater than 1 for multi-process mode workers_count = int(args.workers) if workers_count > 1: From c973498c344fa53c6d26583563ed539bad3e1442 Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 21:35:04 +0800 Subject: [PATCH 47/77] Fix linting --- gunicorn_config.py | 90 ++++++++++++------------- lightrag/api/lightrag_server.py | 16 +++-- lightrag/api/routers/document_routes.py | 14 ++-- lightrag/api/utils_api.py | 4 +- lightrag/lightrag.py | 82 +++++++++++----------- lightrag/operate.py | 3 +- lightrag/utils.py | 32 ++++----- run_with_gunicorn.py | 25 ++++--- 8 files changed, 136 insertions(+), 130 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 0ca6f9d9..810fc721 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -28,58 +28,55 @@ keepalive = 5 errorlog = os.getenv("ERROR_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log accesslog = os.getenv("ACCESS_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log -# ้…็ฝฎๆ—ฅๅฟ—็ณป็ปŸ logconfig_dict = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'standard': { - 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s' + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"}, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "standard", + "stream": "ext://sys.stdout", + }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "standard", + "filename": log_file_path, + "maxBytes": 10485760, # 10MB + "backupCount": 5, + "encoding": "utf8", }, }, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'formatter': 'standard', - 'stream': 
'ext://sys.stdout' - }, - 'file': { - 'class': 'logging.handlers.RotatingFileHandler', - 'formatter': 'standard', - 'filename': log_file_path, - 'maxBytes': 10485760, # 10MB - 'backupCount': 5, - 'encoding': 'utf8' - } - }, - 'filters': { - 'path_filter': { - '()': 'lightrag.api.lightrag_server.LightragPathFilter', + "filters": { + "path_filter": { + "()": "lightrag.api.lightrag_server.LightragPathFilter", }, }, - 'loggers': { - 'lightrag': { - 'handlers': ['console', 'file'], - 'level': loglevel.upper() if loglevel else 'INFO', - 'propagate': False + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, }, - 'gunicorn': { - 'handlers': ['console', 'file'], - 'level': loglevel.upper() if loglevel else 'INFO', - 'propagate': False + "gunicorn": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, }, - 'gunicorn.error': { - 'handlers': ['console', 'file'], - 'level': loglevel.upper() if loglevel else 'INFO', - 'propagate': False + "gunicorn.error": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, }, - 'gunicorn.access': { - 'handlers': ['console', 'file'], - 'level': loglevel.upper() if loglevel else 'INFO', - 'propagate': False, - 'filters': ['path_filter'] - } - } + "gunicorn.access": { + "handlers": ["console", "file"], + "level": loglevel.upper() if loglevel else "INFO", + "propagate": False, + "filters": ["path_filter"], + }, + }, } @@ -134,14 +131,15 @@ def post_fork(server, worker): """ # Set lightrag logger level in worker processes using gunicorn's loglevel from lightrag.utils import logger + logger.setLevel(loglevel.upper()) - + # Disable uvicorn.error logger in worker processes uvicorn_error_logger = logging.getLogger("uvicorn.error") uvicorn_error_logger.setLevel(logging.CRITICAL) uvicorn_error_logger.handlers = [] uvicorn_error_logger.propagate = False - + # Add log filter to uvicorn.access handler in worker processes uvicorn_access_logger = logging.getLogger("uvicorn.access") path_filter = LightragPathFilter() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8f7a6781..d00d39d1 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -9,7 +9,6 @@ from fastapi import ( from fastapi.responses import FileResponse import asyncio import os -import json import logging import logging.config import uvicorn @@ -139,17 +138,20 @@ def create_app(args): # Auto scan documents if enabled if args.auto_scan_at_startup: # Import necessary functions from shared_storage - from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock - + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_storage_lock, + ) + # Get pipeline status and lock pipeline_status = get_namespace_data("pipeline_status") storage_lock = get_storage_lock() - + # Check if a task is already running (with lock protection) should_start_task = False with storage_lock: if not pipeline_status.get("busy", False): - should_start_task = True + should_start_task = True # Only start the task if no other task is running if should_start_task: # Create background task @@ -430,7 +432,7 @@ def configure_logging(): # Configure basic logging log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) - + logging.config.dictConfig( { "version": 1, @@ -453,7 +455,7 @@ def configure_logging(): "formatter": "detailed", "class": 
"logging.handlers.RotatingFileHandler", "filename": log_file_path, - "maxBytes": 10*1024*1024, # 10MB + "maxBytes": 10 * 1024 * 1024, # 10MB "backupCount": 5, "encoding": "utf-8", }, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 3126b8ce..3fdbdf9e 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -406,7 +406,6 @@ def create_document_routes( background_tasks.add_task(run_scanning_process, rag, doc_manager) return {"status": "scanning_started"} - @router.post("/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir( background_tasks: BackgroundTasks, file: UploadFile = File(...) @@ -657,29 +656,30 @@ def create_document_routes( async def get_pipeline_status(): """ Get the current status of the document indexing pipeline. - + This endpoint returns information about the current state of the document processing pipeline, including whether it's busy, the current job name, when it started, how many documents are being processed, how many batches there are, and which batch is currently being processed. - + Returns: dict: A dictionary containing the pipeline status information """ try: from lightrag.kg.shared_storage import get_namespace_data + pipeline_status = get_namespace_data("pipeline_status") - + # Convert to regular dict if it's a Manager.dict status_dict = dict(pipeline_status) - + # Convert history_messages to a regular list if it's a Manager.list if "history_messages" in status_dict: status_dict["history_messages"] = list(status_dict["history_messages"]) - + # Format the job_start time if it exists if status_dict.get("job_start"): status_dict["job_start"] = str(status_dict["job_start"]) - + return status_dict except Exception as e: logger.error(f"Error getting pipeline status: {str(e)}") diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 4b5e0a28..ed1250d4 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -295,7 +295,9 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: original_workers = args.workers args.workers = 1 # Log warning directly here - logging.warning(f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1") + logging.warning( + f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1" + ) # convert relative path to absolute path args.working_dir = os.path.abspath(args.working_dir) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ee5bc397..2dfcae44 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -274,6 +274,7 @@ class LightRAG: from lightrag.kg.shared_storage import ( initialize_share_data, ) + initialize_share_data() if not os.path.exists(self.working_dir): @@ -671,44 +672,45 @@ class LightRAG: 4. 
Update the document status """ from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock - + # Get pipeline status shared data and lock pipeline_status = get_namespace_data("pipeline_status") storage_lock = get_storage_lock() - + # Check if another process is already processing the queue process_documents = False with storage_lock: + # Ensure only one worker is processing documents if not pipeline_status.get("busy", False): - # No other process is busy, we can process documents - # ่Žทๅ–ๅฝ“ๅ‰็š„ history_messages ๅˆ—่กจ + # Cleaning history_messages without breaking it as a shared list object current_history = pipeline_status.get("history_messages", []) - - # ๆธ…็ฉบๅฝ“ๅ‰ๅˆ—่กจๅ†…ๅฎนไฝ†ไฟๆŒๅŒไธ€ไธชๅˆ—่กจๅฏน่ฑก if hasattr(current_history, "clear"): current_history.clear() - - pipeline_status.update({ - "busy": True, - "job_name": "indexing files", - "job_start": datetime.now().isoformat(), - "docs": 0, - "batchs": 0, - "cur_batch": 0, - "request_pending": False, # Clear any previous request - "latest_message": "", - # ไฟๆŒไฝฟ็”จๅŒไธ€ไธชๅˆ—่กจๅฏน่ฑก - "history_messages": current_history, - }) + + pipeline_status.update( + { + "busy": True, + "job_name": "indexing files", + "job_start": datetime.now().isoformat(), + "docs": 0, + "batchs": 0, + "cur_batch": 0, + "request_pending": False, # Clear any previous request + "latest_message": "", + "history_messages": current_history, # keep it as a shared list object + } + ) process_documents = True else: # Another process is busy, just set request flag and return pipeline_status["request_pending"] = True - logger.info("Another process is already processing the document queue. Request queued.") - + logger.info( + "Another process is already processing the document queue. Request queued." + ) + if not process_documents: return - + try: # Process documents until no more documents or requests while True: @@ -734,7 +736,7 @@ class LightRAG: # Update pipeline status with document count (with lock) with storage_lock: pipeline_status["docs"] = len(to_process_docs) - + # 2. split docs into chunks, insert chunks, update doc status docs_batches = [ list(to_process_docs.items())[i : i + self.max_parallel_insert] @@ -742,11 +744,8 @@ class LightRAG: ] # Update pipeline status with batch information (directly, as it's atomic) - pipeline_status.update({ - "batchs": len(docs_batches), - "cur_batch": 0 - }) - + pipeline_status.update({"batchs": len(docs_batches), "cur_batch": 0}) + log_message = f"Number of batches to process: {len(docs_batches)}." logger.info(log_message) pipeline_status["latest_message"] = log_message @@ -757,13 +756,15 @@ class LightRAG: for batch_idx, docs_batch in enumerate(docs_batches): # Update current batch in pipeline status (directly, as it's atomic) pipeline_status["cur_batch"] = batch_idx + 1 - + async def batch( batch_idx: int, docs_batch: list[tuple[str, DocProcessingStatus]], size_batch: int, ) -> None: - log_message = f"Start processing batch {batch_idx + 1} of {size_batch}." + log_message = ( + f"Start processing batch {batch_idx + 1} of {size_batch}." 
+ ) logger.info(log_message) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) @@ -822,7 +823,9 @@ class LightRAG: } ) except Exception as e: - logger.error(f"Failed to process document {doc_id}: {str(e)}") + logger.error( + f"Failed to process document {doc_id}: {str(e)}" + ) await self.doc_status.upsert( { doc_id: { @@ -837,7 +840,9 @@ class LightRAG: } ) continue - log_message = f"Completed batch {batch_idx + 1} of {len(docs_batches)}." + log_message = ( + f"Completed batch {batch_idx + 1} of {len(docs_batches)}." + ) logger.info(log_message) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) @@ -846,7 +851,7 @@ class LightRAG: await asyncio.gather(*batches) await self._insert_done() - + # Check if there's a pending request to process more documents (with lock) has_pending_request = False with storage_lock: @@ -854,15 +859,15 @@ class LightRAG: if has_pending_request: # Clear the request flag before checking for more documents pipeline_status["request_pending"] = False - + if not has_pending_request: break - + log_message = "Processing additional documents due to pending request" logger.info(log_message) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - + finally: # Always reset busy status when done or if an exception occurs (with lock) with storage_lock: @@ -901,12 +906,13 @@ class LightRAG: if storage_inst is not None ] await asyncio.gather(*tasks) - + log_message = "All Insert done" logger.info(log_message) - + # ่Žทๅ– pipeline_status ๅนถๆ›ดๆ–ฐ latest_message ๅ’Œ history_messages from lightrag.kg.shared_storage import get_namespace_data + pipeline_status = get_namespace_data("pipeline_status") pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) diff --git a/lightrag/operate.py b/lightrag/operate.py index 44a68655..59dfb063 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -336,8 +336,9 @@ async def extract_entities( global_config: dict[str, str], llm_response_cache: BaseKVStorage | None = None, ) -> None: - # ๅœจๅ‡ฝๆ•ฐๅผ€ๅง‹ๅค„ๆทปๅŠ ่Žทๅ– pipeline_status ็š„ไปฃ็  + from lightrag.kg.shared_storage import get_namespace_data + pipeline_status = get_namespace_data("pipeline_status") use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] diff --git a/lightrag/utils.py b/lightrag/utils.py index 3ec96112..5e579a06 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -75,50 +75,42 @@ def set_logger(log_file: str, level: int = logging.DEBUG): log_file: Path to the log file level: Logging level (e.g. 
logging.DEBUG, logging.INFO) """ - # ่ฎพ็ฝฎๆ—ฅๅฟ—็บงๅˆซ + logger.setLevel(level) - - # ็กฎไฟไฝฟ็”จ็ปๅฏน่ทฏๅพ„ log_file = os.path.abspath(log_file) - - # ๅˆ›ๅปบๆ ผๅผๅŒ–ๅ™จ formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) - - # ๆฃ€ๆŸฅๆ˜ฏๅฆๅทฒ็ปๆœ‰ๆ–‡ไปถๅค„็†ๅ™จ has_file_handler = False has_console_handler = False - - # ๆฃ€ๆŸฅ็Žฐๆœ‰ๅค„็†ๅ™จ + for handler in logger.handlers: if isinstance(handler, logging.FileHandler): has_file_handler = True - elif isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler): + elif isinstance(handler, logging.StreamHandler) and not isinstance( + handler, logging.FileHandler + ): has_console_handler = True - - # ๅฆ‚ๆžœๆฒกๆœ‰ๆ–‡ไปถๅค„็†ๅ™จ๏ผŒๆทปๅŠ ไธ€ไธช + if not has_file_handler: - # ไฝฟ็”จ RotatingFileHandler ไปฃๆ›ฟ FileHandler from logging.handlers import RotatingFileHandler + file_handler = RotatingFileHandler( - log_file, - maxBytes=10*1024*1024, # 10MB + log_file, + maxBytes=10 * 1024 * 1024, # 10MB backupCount=5, - encoding="utf-8" + encoding="utf-8", ) file_handler.setLevel(level) file_handler.setFormatter(formatter) logger.addHandler(file_handler) - - # ๅฆ‚ๆžœๆฒกๆœ‰ๆŽงๅˆถๅฐๅค„็†ๅ™จ๏ผŒๆทปๅŠ ไธ€ไธช + if not has_console_handler: console_handler = logging.StreamHandler() console_handler.setLevel(level) console_handler.setFormatter(formatter) logger.addHandler(console_handler) - - # ่ฎพ็ฝฎๆ—ฅๅฟ—ไผ ๆ’ญไธบ False๏ผŒ้ฟๅ…้‡ๅค่พ“ๅ‡บ + logger.propagate = False diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index de2b21b6..94cbdedf 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -5,9 +5,7 @@ Start LightRAG server with Gunicorn import os import sys -import json import signal -import argparse from lightrag.api.utils_api import parse_args, display_splash_screen from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data @@ -34,7 +32,6 @@ def main(): # Parse all arguments using parse_args args = parse_args(is_uvicorn_mode=False) - # Display startup information display_splash_screen(args) @@ -101,9 +98,15 @@ def main(): # Set configuration variables in gunicorn_config gunicorn_config.workers = int(os.getenv("WORKERS", args.workers)) - gunicorn_config.bind = f"{os.getenv('HOST', args.host)}:{os.getenv('PORT', args.port)}" - gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info") - + gunicorn_config.bind = ( + f"{os.getenv('HOST', args.host)}:{os.getenv('PORT', args.port)}" + ) + gunicorn_config.loglevel = ( + args.log_level.lower() + if args.log_level + else os.getenv("LOG_LEVEL", "info") + ) + # Set SSL configuration if enabled if args.ssl: gunicorn_config.certfile = args.ssl_certfile @@ -121,10 +124,12 @@ def main(): value = getattr(gunicorn_config, key) if callable(value): self.cfg.set(key, value) - - # ็กฎไฟๆญฃ็กฎๅŠ ่ฝฝ logconfig_dict - if hasattr(gunicorn_config, 'logconfig_dict'): - self.cfg.set('logconfig_dict', getattr(gunicorn_config, 'logconfig_dict')) + + + if hasattr(gunicorn_config, "logconfig_dict"): + self.cfg.set( + "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") + ) def load(self): # Import the application From 731d820bcc31190717d5ee01853205c5223045cf Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 21:46:45 +0800 Subject: [PATCH 48/77] Remove redundancy set_logger function and related calls --- lightrag/kg/shared_storage.py | 32 ++++++++++++------------ lightrag/lightrag.py | 4 +-- lightrag/operate.py | 1 - lightrag/utils.py | 46 
----------------------------------- run_with_gunicorn.py | 1 - 5 files changed, 18 insertions(+), 66 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 3a21dc5c..4cad25fa 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -80,24 +80,26 @@ def initialize_share_data(workers: int = 1): # Mark as initialized _initialized = True - + # Initialize pipeline status for document indexing control pipeline_namespace = get_namespace_data("pipeline_status") - - # ๅˆ›ๅปบไธ€ไธชๅ…ฑไบซๅˆ—่กจๅฏน่ฑก็”จไบŽ history_messages + + # Create a shared list object for history_messages history_messages = _manager.list() if is_multiprocess else [] - - pipeline_namespace.update({ - "busy": False, # Control concurrent processes - "job_name": "Default Job", # Current job name (indexing files/indexing texts) - "job_start": None, # Job start time - "docs": 0, # Total number of documents to be indexed - "batchs": 0, # Number of batches for processing documents - "cur_batch": 0, # Current processing batch - "request_pending": False, # Flag for pending request for processing - "latest_message": "", # Latest message from pipeline processing - "history_messages": history_messages, # ไฝฟ็”จๅ…ฑไบซๅˆ—่กจๅฏน่ฑก - }) + + pipeline_namespace.update( + { + "busy": False, # Control concurrent processes + "job_name": "Default Job", # Current job name (indexing files/indexing texts) + "job_start": None, # Job start time + "docs": 0, # Total number of documents to be indexed + "batchs": 0, # Number of batches for processing documents + "cur_batch": 0, # Current processing batch + "request_pending": False, # Flag for pending request for processing + "latest_message": "", # Latest message from pipeline processing + "history_messages": history_messages, # ไฝฟ็”จๅ…ฑไบซๅˆ—่กจๅฏน่ฑก + } + ) def try_initialize_namespace(namespace: str) -> bool: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 2dfcae44..9c8f84ff 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -45,7 +45,6 @@ from .utils import ( lazy_external_import, limit_async_func_call, logger, - set_logger, ) from .types import KnowledgeGraph from dotenv import load_dotenv @@ -268,7 +267,6 @@ class LightRAG: def __post_init__(self): os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) - set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") from lightrag.kg.shared_storage import ( @@ -682,7 +680,7 @@ class LightRAG: with storage_lock: # Ensure only one worker is processing documents if not pipeline_status.get("busy", False): - # Cleaning history_messages without breaking it as a shared list object + # Cleaning history_messages without breaking it as a shared list object current_history = pipeline_status.get("history_messages", []) if hasattr(current_history, "clear"): current_history.clear() diff --git a/lightrag/operate.py b/lightrag/operate.py index 59dfb063..5db5b5c6 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -336,7 +336,6 @@ async def extract_entities( global_config: dict[str, str], llm_response_cache: BaseKVStorage | None = None, ) -> None: - from lightrag.kg.shared_storage import get_namespace_data pipeline_status = get_namespace_data("pipeline_status") diff --git a/lightrag/utils.py b/lightrag/utils.py index 5e579a06..c86ad8c0 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -68,52 +68,6 @@ logger.setLevel(logging.INFO) logging.getLogger("httpx").setLevel(logging.WARNING) -def 
set_logger(log_file: str, level: int = logging.DEBUG): - """Set up file logging with the specified level. - - Args: - log_file: Path to the log file - level: Logging level (e.g. logging.DEBUG, logging.INFO) - """ - - logger.setLevel(level) - log_file = os.path.abspath(log_file) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - has_file_handler = False - has_console_handler = False - - for handler in logger.handlers: - if isinstance(handler, logging.FileHandler): - has_file_handler = True - elif isinstance(handler, logging.StreamHandler) and not isinstance( - handler, logging.FileHandler - ): - has_console_handler = True - - if not has_file_handler: - from logging.handlers import RotatingFileHandler - - file_handler = RotatingFileHandler( - log_file, - maxBytes=10 * 1024 * 1024, # 10MB - backupCount=5, - encoding="utf-8", - ) - file_handler.setLevel(level) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - if not has_console_handler: - console_handler = logging.StreamHandler() - console_handler.setLevel(level) - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - logger.propagate = False - - class UnlimitedSemaphore: """A context manager that allows unlimited access.""" diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 94cbdedf..a7692085 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -125,7 +125,6 @@ def main(): if callable(value): self.cfg.set(key, value) - if hasattr(gunicorn_config, "logconfig_dict"): self.cfg.set( "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") From aac1bdd9e6e63b8218369d30d9e82cd9dc4bf6ce Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 23:21:14 +0800 Subject: [PATCH 49/77] feat: add configurable log rotation settings via environment variables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add LOG_DIR env var for log file location โ€ข Add LOG_MAX_BYTES for max log file size โ€ข Add LOG_BACKUP_COUNT for backup count --- .env.example | 3 +++ gunicorn_config.py | 17 +++++++++++------ lightrag/api/lightrag_server.py | 13 +++++++++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/.env.example b/.env.example index 8a14cdb3..de9b6452 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,9 @@ ### Logging level # LOG_LEVEL=INFO # VERBOSE=False +# LOG_DIR=/path/to/log/directory # Log file directory path, defaults to current working directory +# LOG_MAX_BYTES=10485760 # Log file max size in bytes, defaults to 10MB +# LOG_BACKUP_COUNT=5 # Number of backup files to keep, defaults to 5 ### Max async calls for LLM # MAX_ASYNC=4 diff --git a/gunicorn_config.py b/gunicorn_config.py index 810fc721..a13054e3 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -4,8 +4,13 @@ import logging from lightrag.kg.shared_storage import finalize_share_data from lightrag.api.lightrag_server import LightragPathFilter -# ่Žทๅ–ๆ—ฅๅฟ—ๆ–‡ไปถ่ทฏๅพ„ -log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) +# Get log directory path from environment variable +log_dir = os.getenv("LOG_DIR", os.getcwd()) +log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + +# Get log file max size and backup count from environment variables +log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB +log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups # These variables will be set by run_with_gunicorn.py workers = None 
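
The three settings introduced in the hunk above (LOG_DIR, LOG_MAX_BYTES, LOG_BACKUP_COUNT) are ultimately consumed by Python's standard logging.handlers.RotatingFileHandler. A minimal stand-alone sketch of that pattern, reusing the same environment variables and defaults as this patch — the logger name and log message below are placeholders, not LightRAG code:

    import logging
    import os
    from logging.handlers import RotatingFileHandler

    # Resolve the log destination and rotation limits from the environment,
    # falling back to the same defaults as gunicorn_config.py (10MB, 5 backups).
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))
    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))

    handler = RotatingFileHandler(
        log_file_path,
        maxBytes=log_max_bytes,        # rotate once the file exceeds this size
        backupCount=log_backup_count,  # keep this many rotated files
        encoding="utf-8",
    )
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )

    demo_logger = logging.getLogger("rotation_demo")  # placeholder logger name
    demo_logger.setLevel(logging.INFO)
    demo_logger.addHandler(handler)
    demo_logger.info("log rotation configured")

Once lightrag.log exceeds LOG_MAX_BYTES it is renamed to lightrag.log.1 (then .2, and so on up to LOG_BACKUP_COUNT) and logging continues in a fresh lightrag.log.
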
@@ -25,8 +30,8 @@ timeout = int(os.getenv("TIMEOUT", 120)) keepalive = 5 # Logging configuration -errorlog = os.getenv("ERROR_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log -accesslog = os.getenv("ACCESS_LOG", log_file_path) # ้ป˜่ฎคๅ†™ๅ…ฅๅˆฐ lightrag.log +errorlog = os.getenv("ERROR_LOG", log_file_path) # Default write to lightrag.log +accesslog = os.getenv("ACCESS_LOG", log_file_path) # Default write to lightrag.log logconfig_dict = { "version": 1, @@ -44,8 +49,8 @@ logconfig_dict = { "class": "logging.handlers.RotatingFileHandler", "formatter": "standard", "filename": log_file_path, - "maxBytes": 10485760, # 10MB - "backupCount": 5, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, "encoding": "utf8", }, }, diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index d00d39d1..4d0a6390 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -430,8 +430,13 @@ def configure_logging(): logger.handlers = [] logger.filters = [] - # Configure basic logging - log_file_path = os.path.abspath(os.path.join(os.getcwd(), "lightrag.log")) + # Get log directory path from environment variable + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups logging.config.dictConfig( { @@ -455,8 +460,8 @@ def configure_logging(): "formatter": "detailed", "class": "logging.handlers.RotatingFileHandler", "filename": log_file_path, - "maxBytes": 10 * 1024 * 1024, # 10MB - "backupCount": 5, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, "encoding": "utf-8", }, }, From a721421bd8e8a59e1f653695d67748e5999abacd Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 01:49:26 +0800 Subject: [PATCH 50/77] Add async support and update flag mechanism for shared storage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Use asyncio.Lock instead of thread lock for single process mode โ€ข Add storage update notification system --- lightrag/kg/shared_storage.py | 86 +++++++++++++++++++++++++++++------ 1 file changed, 73 insertions(+), 13 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 4cad25fa..7ac0d625 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,7 +1,7 @@ import os import sys +import asyncio from multiprocessing.synchronize import Lock as ProcessLock -from threading import Lock as ThreadLock from multiprocessing import Manager from typing import Any, Dict, Optional, Union @@ -15,16 +15,18 @@ def direct_log(message, level="INFO"): print(f"{level}: {message}", file=sys.stderr, flush=True) -LockType = Union[ProcessLock, ThreadLock] +LockType = Union[ProcessLock, asyncio.Lock] +is_multiprocess = None +_workers = None _manager = None _initialized = None -is_multiprocess = None _global_lock: Optional[LockType] = None # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized +_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated def initialize_share_data(workers: int = 1): @@ -47,12 +49,14 @@ def initialize_share_data(workers: int = 1): """ global \ _manager, \ + _workers, \ is_multiprocess, \ 
is_multiprocess, \ _global_lock, \ _shared_dicts, \ _init_flags, \ - _initialized + _initialized, \ + _update_flags # Check if already initialized if _initialized: @@ -62,20 +66,23 @@ def initialize_share_data(workers: int = 1): return _manager = Manager() + _workers = workers if workers > 1: is_multiprocess = True _global_lock = _manager.Lock() _shared_dicts = _manager.dict() _init_flags = _manager.dict() + _update_flags = _manager.dict() direct_log( f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" ) else: is_multiprocess = False - _global_lock = ThreadLock() + _global_lock = asyncio.Lock() _shared_dicts = {} _init_flags = {} + _update_flags = {} direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") # Mark as initialized @@ -86,7 +93,6 @@ def initialize_share_data(workers: int = 1): # Create a shared list object for history_messages history_messages = _manager.list() if is_multiprocess else [] - pipeline_namespace.update( { "busy": False, # Control concurrent processes @@ -102,6 +108,58 @@ def initialize_share_data(workers: int = 1): ) +async def get_update_flags(namespace: str): + """ + Create a updated flags of a specific namespace. + Caller must store the dict object locally and use it to determine when to update the storage. + """ + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + if is_multiprocess: + with _global_lock: + if namespace not in _update_flags: + if _manager is not None: + _update_flags[namespace] = _manager.list() + direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") + + if _manager is not None: + new_update_flag = _manager.Value('b', False) + _update_flags[namespace].append(new_update_flag) + return new_update_flag + else: + async with _global_lock: + if namespace not in _update_flags: + _update_flags[namespace] = [] + direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") + + new_update_flag = False + _update_flags[namespace].append(new_update_flag) + return new_update_flag + +async def set_update_flag(namespace: str): + """Set all update flag of namespace to indicate storage needs updating""" + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + if is_multiprocess: + with _global_lock: + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not found in update flags") + # Update flags for multiprocess mode + for i in range(len(_update_flags[namespace])): + _update_flags[namespace][i].value = True + else: + async with _global_lock: + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not found in update flags") + # Update flags for single process mode + for i in range(len(_update_flags[namespace])): + _update_flags[namespace][i] = True + + def try_initialize_namespace(namespace: str) -> bool: """ Try to initialize a namespace. Returns True if the current process gets initialization permission. 
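
The update-flag mechanism added in the hunk above builds on multiprocessing.Manager proxies: every worker registers its own flag object in a shared per-namespace list, and any process can flip all registered flags to announce that the shared storage has changed. A minimal stand-alone sketch of that cross-process pattern for the multi-worker case — the polling loop, timings and messages are illustrative only, not LightRAG code:

    import os
    import time
    from multiprocessing import Manager, Process

    def worker(update_flags):
        # Each worker keeps a reference to its own flag and polls it; a True
        # value means another process changed the shared data.
        my_flag = update_flags[0]
        for _ in range(20):
            if my_flag.value:
                print(f"pid {os.getpid()}: shared data changed, reloading")
                my_flag.value = False  # acknowledge the update
            time.sleep(0.05)

    if __name__ == "__main__":
        manager = Manager()
        update_flags = manager.list()                   # one entry per worker
        update_flags.append(manager.Value("b", False))  # same as _manager.Value('b', False)

        p = Process(target=worker, args=(update_flags,))
        p.start()

        time.sleep(0.2)
        update_flags[0].value = True  # another process signals an update
        p.join()

This mirrors the functions above: get_update_flags hands each caller its own flag object to store locally, while set_update_flag walks the namespace's list and sets every registered flag.
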
@@ -129,7 +187,7 @@ def get_storage_lock() -> LockType: return _global_lock -def get_namespace_data(namespace: str) -> Dict[str, Any]: +async def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" if _shared_dicts is None: direct_log( @@ -138,12 +196,14 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: ) raise ValueError("Shared dictionaries not initialized") - lock = get_storage_lock() - with lock: - if namespace not in _shared_dicts: - if is_multiprocess and _manager is not None: - _shared_dicts[namespace] = _manager.dict() - else: + if is_multiprocess: + with _global_lock: + if namespace not in _shared_dicts: + if _manager is not None: + _shared_dicts[namespace] = _manager.dict() + else: + async with _global_lock: + if namespace not in _shared_dicts: _shared_dicts[namespace] = {} return _shared_dicts[namespace] From b3328542c71e9fb8e6894340d1174154fb95dec5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 02:22:35 +0800 Subject: [PATCH 51/77] refactor: migrate synchronous locks to async locks for improved concurrency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add UnifiedLock wrapper class โ€ข Convert with blocks to async with --- lightrag/api/lightrag_server.py | 7 +- lightrag/kg/json_doc_status_impl.py | 24 +++--- lightrag/kg/json_kv_impl.py | 18 ++--- lightrag/kg/shared_storage.py | 111 +++++++++++++++++----------- lightrag/lightrag.py | 21 +++--- 5 files changed, 102 insertions(+), 79 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 4d0a6390..c49de7a4 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -143,13 +143,10 @@ def create_app(args): get_storage_lock, ) - # Get pipeline status and lock - pipeline_status = get_namespace_data("pipeline_status") - storage_lock = get_storage_lock() - # Check if a task is already running (with lock protection) + pipeline_status = await get_namespace_data("pipeline_status") should_start_task = False - with storage_lock: + async with get_storage_lock(): if not pipeline_status.get("busy", False): should_start_task = True # Only start the task if no other task is running diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 05e6da37..6a825db4 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -24,17 +24,17 @@ from .shared_storage import ( class JsonDocStatusStorage(DocStatusStorage): """JSON implementation of document status storage""" - def __post_init__(self): + async def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) - self._data = get_namespace_data(self.namespace) + self._data = await get_namespace_data(self.namespace) if need_init: loaded_data = load_json(self._file_name) or {} - with self._storage_lock: + async with self._storage_lock: self._data.update(loaded_data) logger.info( f"Loaded document status storage with {len(loaded_data)} records" @@ -42,12 +42,12 @@ class JsonDocStatusStorage(DocStatusStorage): async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - with self._storage_lock: + async with 
self._storage_lock: return set(keys) - set(self._data.keys()) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] - with self._storage_lock: + async with self._storage_lock: for id in ids: data = self._data.get(id, None) if data: @@ -57,7 +57,7 @@ class JsonDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" counts = {status.value: 0 for status in DocStatus} - with self._storage_lock: + async with self._storage_lock: for doc in self._data.values(): counts[doc["status"]] += 1 return counts @@ -67,7 +67,7 @@ class JsonDocStatusStorage(DocStatusStorage): ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" result = {} - with self._storage_lock: + async with self._storage_lock: for k, v in self._data.items(): if v["status"] == status.value: try: @@ -83,7 +83,7 @@ class JsonDocStatusStorage(DocStatusStorage): return result async def index_done_callback(self) -> None: - with self._storage_lock: + async with self._storage_lock: data_dict = ( dict(self._data) if hasattr(self._data, "_getvalue") else self._data ) @@ -94,21 +94,21 @@ class JsonDocStatusStorage(DocStatusStorage): if not data: return - with self._storage_lock: + async with self._storage_lock: self._data.update(data) await self.index_done_callback() async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - with self._storage_lock: + async with self._storage_lock: return self._data.get(id) async def delete(self, doc_ids: list[str]): - with self._storage_lock: + async with self._storage_lock: for doc_id in doc_ids: self._data.pop(doc_id, None) await self.index_done_callback() async def drop(self) -> None: """Drop the storage""" - with self._storage_lock: + async with self._storage_lock: self._data.clear() diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index a4ce91a5..424730c1 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -20,33 +20,33 @@ from .shared_storage import ( @final @dataclass class JsonKVStorage(BaseKVStorage): - def __post_init__(self): + async def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) - self._data = get_namespace_data(self.namespace) + self._data = await get_namespace_data(self.namespace) if need_init: loaded_data = load_json(self._file_name) or {} - with self._storage_lock: + async with self._storage_lock: self._data.update(loaded_data) logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data") async def index_done_callback(self) -> None: - with self._storage_lock: + async with self._storage_lock: data_dict = ( dict(self._data) if hasattr(self._data, "_getvalue") else self._data ) write_json(data_dict, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: - with self._storage_lock: + async with self._storage_lock: return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - with self._storage_lock: + async with self._storage_lock: return [ ( {k: v for k, v in self._data[id].items()} @@ -57,19 +57,19 @@ class JsonKVStorage(BaseKVStorage): ] async def filter_keys(self, keys: set[str]) -> set[str]: - with self._storage_lock: + async with self._storage_lock: return set(keys) - 
set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") if not data: return - with self._storage_lock: + async with self._storage_lock: left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) async def delete(self, ids: list[str]) -> None: - with self._storage_lock: + async with self._storage_lock: for doc_id in ids: self._data.pop(doc_id, None) await self.index_done_callback() diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 7ac0d625..ef946b44 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -3,7 +3,7 @@ import sys import asyncio from multiprocessing.synchronize import Lock as ProcessLock from multiprocessing import Manager -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, TypeVar, Generic # Define a direct print function for critical logs that must be visible in all processes @@ -15,6 +15,43 @@ def direct_log(message, level="INFO"): print(f"{level}: {message}", file=sys.stderr, flush=True) +T = TypeVar('T') + +class UnifiedLock(Generic[T]): + """็ปŸไธ€็š„้”ๅŒ…่ฃ…็ฑป๏ผŒๆไพ›ๅŒๆญฅๅ’Œๅผ‚ๆญฅ็š„็ปŸไธ€ๆŽฅๅฃ""" + def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): + self._lock = lock + self._is_async = is_async + + async def __aenter__(self) -> 'UnifiedLock[T]': + """ๅผ‚ๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ…ฅๅฃ""" + if self._is_async: + await self._lock.acquire() + else: + self._lock.acquire() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """ๅผ‚ๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ‡บๅฃ""" + if self._is_async: + self._lock.release() + else: + self._lock.release() + + def __enter__(self) -> 'UnifiedLock[T]': + """ๅŒๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ…ฅๅฃ๏ผˆไป…็”จไบŽๅ‘ๅŽๅ…ผๅฎน๏ผ‰""" + if self._is_async: + raise RuntimeError("Use 'async with' for asyncio.Lock") + self._lock.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ๅŒๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ‡บๅฃ๏ผˆไป…็”จไบŽๅ‘ๅŽๅ…ผๅฎน๏ผ‰""" + if self._is_async: + raise RuntimeError("Use 'async with' for asyncio.Lock") + self._lock.release() + + LockType = Union[ProcessLock, asyncio.Lock] is_multiprocess = None @@ -117,26 +154,21 @@ async def get_update_flags(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - if is_multiprocess: - with _global_lock: - if namespace not in _update_flags: - if _manager is not None: - _update_flags[namespace] = _manager.list() - direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") - - if _manager is not None: - new_update_flag = _manager.Value('b', False) - _update_flags[namespace].append(new_update_flag) - return new_update_flag - else: - async with _global_lock: - if namespace not in _update_flags: + async with get_storage_lock(): + if namespace not in _update_flags: + if is_multiprocess and _manager is not None: + _update_flags[namespace] = _manager.list() + else: _update_flags[namespace] = [] - direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") - + direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") + + if is_multiprocess and _manager is not None: + new_update_flag = _manager.Value('b', False) + else: new_update_flag = False - _update_flags[namespace].append(new_update_flag) - return new_update_flag + + _update_flags[namespace].append(new_update_flag) 
+ return new_update_flag async def set_update_flag(namespace: str): """Set all update flag of namespace to indicate storage needs updating""" @@ -144,19 +176,14 @@ async def set_update_flag(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - if is_multiprocess: - with _global_lock: - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") - # Update flags for multiprocess mode - for i in range(len(_update_flags[namespace])): + async with get_storage_lock(): + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not found in update flags") + # Update flags for both modes + for i in range(len(_update_flags[namespace])): + if is_multiprocess: _update_flags[namespace][i].value = True - else: - async with _global_lock: - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") - # Update flags for single process mode - for i in range(len(_update_flags[namespace])): + else: _update_flags[namespace][i] = True @@ -182,9 +209,12 @@ def try_initialize_namespace(namespace: str) -> bool: return False -def get_storage_lock() -> LockType: - """return storage lock for data consistency""" - return _global_lock +def get_storage_lock() -> UnifiedLock: + """return unified storage lock for data consistency""" + return UnifiedLock( + lock=_global_lock, + is_async=not is_multiprocess + ) async def get_namespace_data(namespace: str) -> Dict[str, Any]: @@ -196,14 +226,11 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]: ) raise ValueError("Shared dictionaries not initialized") - if is_multiprocess: - with _global_lock: - if namespace not in _shared_dicts: - if _manager is not None: - _shared_dicts[namespace] = _manager.dict() - else: - async with _global_lock: - if namespace not in _shared_dicts: + async with get_storage_lock(): + if namespace not in _shared_dicts: + if is_multiprocess and _manager is not None: + _shared_dicts[namespace] = _manager.dict() + else: _shared_dicts[namespace] = {} return _shared_dicts[namespace] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9c8f84ff..4b85a3b7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -672,12 +672,12 @@ class LightRAG: from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock # Get pipeline status shared data and lock - pipeline_status = get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data("pipeline_status") storage_lock = get_storage_lock() # Check if another process is already processing the queue process_documents = False - with storage_lock: + async with storage_lock: # Ensure only one worker is processing documents if not pipeline_status.get("busy", False): # Cleaning history_messages without breaking it as a shared list object @@ -732,8 +732,7 @@ class LightRAG: break # Update pipeline status with document count (with lock) - with storage_lock: - pipeline_status["docs"] = len(to_process_docs) + pipeline_status["docs"] = len(to_process_docs) # 2. 
split docs into chunks, insert chunks, update doc status docs_batches = [ @@ -852,7 +851,7 @@ class LightRAG: # Check if there's a pending request to process more documents (with lock) has_pending_request = False - with storage_lock: + async with storage_lock: has_pending_request = pipeline_status.get("request_pending", False) if has_pending_request: # Clear the request flag before checking for more documents @@ -867,13 +866,13 @@ class LightRAG: pipeline_status["history_messages"].append(log_message) finally: - # Always reset busy status when done or if an exception occurs (with lock) - with storage_lock: - pipeline_status["busy"] = False log_message = "Document processing pipeline completed" logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Always reset busy status when done or if an exception occurs (with lock) + async with storage_lock: + pipeline_status["busy"] = False + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: try: @@ -911,7 +910,7 @@ class LightRAG: # ่Žทๅ– pipeline_status ๅนถๆ›ดๆ–ฐ latest_message ๅ’Œ history_messages from lightrag.kg.shared_storage import get_namespace_data - pipeline_status = get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data("pipeline_status") pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) From fd76e00c6a4ccb3d92db959c909d8962f71bee5b Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 03:48:19 +0800 Subject: [PATCH 52/77] Refactor storage initialization to separate object creation from data loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Split __post_init__ and initialize() โ€ข Move data loading to initialize() โ€ข Add FastAPI lifespan integration --- lightrag/api/lightrag_server.py | 14 ++++++++------ lightrag/kg/json_doc_status_impl.py | 5 ++++- lightrag/kg/json_kv_impl.py | 5 ++++- lightrag/kg/nano_vector_db_impl.py | 13 ++++++++----- lightrag/kg/shared_storage.py | 24 ++++++++++++++++-------- 5 files changed, 40 insertions(+), 21 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index c49de7a4..ca0958ee 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -135,14 +135,16 @@ def create_app(args): # Initialize database connections await rag.initialize_storages() + # Import necessary functions from shared_storage + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_storage_lock, + initialize_pipeline_namespace, + ) + await initialize_pipeline_namespace() + # Auto scan documents if enabled if args.auto_scan_at_startup: - # Import necessary functions from shared_storage - from lightrag.kg.shared_storage import ( - get_namespace_data, - get_storage_lock, - ) - # Check if a task is already running (with lock protection) pipeline_status = await get_namespace_data("pipeline_status") should_start_task = False diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 6a825db4..01c657fa 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -24,11 +24,14 @@ from .shared_storage import ( class JsonDocStatusStorage(DocStatusStorage): """JSON implementation of document status storage""" - async def __post_init__(self): + def 
__post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + self._data = None + async def initialize(self): + """Initialize storage data""" # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) self._data = await get_namespace_data(self.namespace) diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 424730c1..8d707899 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -20,11 +20,14 @@ from .shared_storage import ( @final @dataclass class JsonKVStorage(BaseKVStorage): - async def __post_init__(self): + def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + self._data = None + async def initialize(self): + """Initialize storage data""" # check need_init must before get_namespace_data need_init = try_initialize_namespace(self.namespace) self._data = await get_namespace_data(self.namespace) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index bbf991bf..86381379 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -16,15 +16,16 @@ if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB -from threading import Lock as ThreadLock +from .shared_storage import get_storage_lock @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): - # Initialize lock only for file operations - self._storage_lock = ThreadLock() + # Initialize basic attributes + self._storage_lock = get_storage_lock() + self._client = None # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -40,7 +41,9 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self._max_batch_size = self.global_config["embedding_batch_num"] - with self._storage_lock: + async def initialize(self): + """Initialize storage data""" + async with self._storage_lock: self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name, @@ -163,5 +166,5 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self) -> None: - with self._storage_lock: + async with self._storage_lock: self._get_client().save() diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index ef946b44..5f795f0f 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -125,13 +125,21 @@ def initialize_share_data(workers: int = 1): # Mark as initialized _initialized = True - # Initialize pipeline status for document indexing control - pipeline_namespace = get_namespace_data("pipeline_status") - # Create a shared list object for history_messages - history_messages = _manager.list() if is_multiprocess else [] - pipeline_namespace.update( - { +async def initialize_pipeline_namespace(): + """ + Initialize pipeline namespace with default values. 
+ """ + pipeline_namespace = await get_namespace_data("pipeline_status") + + async with get_storage_lock(): + # Check if already initialized by checking for required fields + if "busy" in pipeline_namespace: + return + + # Create a shared list object for history_messages + history_messages = _manager.list() if is_multiprocess else [] + pipeline_namespace.update({ "busy": False, # Control concurrent processes "job_name": "Default Job", # Current job name (indexing files/indexing texts) "job_start": None, # Job start time @@ -141,8 +149,8 @@ def initialize_share_data(workers: int = 1): "request_pending": False, # Flag for pending request for processing "latest_message": "", # Latest message from pipeline processing "history_messages": history_messages, # ไฝฟ็”จๅ…ฑไบซๅˆ—่กจๅฏน่ฑก - } - ) + }) + direct_log(f"Process {os.getpid()} Pipeline namespace initialized") async def get_update_flags(namespace: str): From d7045121394468dfd32f58feb74efbac1c6ccb5c Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 05:01:26 +0800 Subject: [PATCH 53/77] Refactor shared storage module to improve async handling and naming consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add async support for get_namespace_data โ€ข Rename get_update_flags to get_update_flag โ€ข Rename set_update_flag to set_all_update_flags โ€ข Update docstrings for clarity โ€ข Fix typos in log messages --- lightrag/api/routers/document_routes.py | 2 +- lightrag/kg/shared_storage.py | 20 +++++++++----------- lightrag/operate.py | 2 +- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 3fdbdf9e..ab5aff96 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -667,7 +667,7 @@ def create_document_routes( try: from lightrag.kg.shared_storage import get_namespace_data - pipeline_status = get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data("pipeline_status") # Convert to regular dict if it's a Manager.dict status_dict = dict(pipeline_status) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 5f795f0f..940d0e7b 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -18,13 +18,12 @@ def direct_log(message, level="INFO"): T = TypeVar('T') class UnifiedLock(Generic[T]): - """็ปŸไธ€็š„้”ๅŒ…่ฃ…็ฑป๏ผŒๆไพ›ๅŒๆญฅๅ’Œๅผ‚ๆญฅ็š„็ปŸไธ€ๆŽฅๅฃ""" + """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): self._lock = lock self._is_async = is_async async def __aenter__(self) -> 'UnifiedLock[T]': - """ๅผ‚ๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ…ฅๅฃ""" if self._is_async: await self._lock.acquire() else: @@ -32,21 +31,20 @@ class UnifiedLock(Generic[T]): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """ๅผ‚ๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ‡บๅฃ""" if self._is_async: self._lock.release() else: self._lock.release() def __enter__(self) -> 'UnifiedLock[T]': - """ๅŒๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ…ฅๅฃ๏ผˆไป…็”จไบŽๅ‘ๅŽๅ…ผๅฎน๏ผ‰""" + """For backward compatibility""" if self._is_async: raise RuntimeError("Use 'async with' for asyncio.Lock") self._lock.acquire() return self def __exit__(self, exc_type, exc_val, exc_tb): - """ๅŒๆญฅไธŠไธ‹ๆ–‡็ฎก็†ๅ™จๅ‡บๅฃ๏ผˆไป…็”จไบŽๅ‘ๅŽๅ…ผๅฎน๏ผ‰""" + """For backward compatibility""" if self._is_async: raise RuntimeError("Use 'async with' for 
asyncio.Lock") self._lock.release() @@ -153,10 +151,10 @@ async def initialize_pipeline_namespace(): direct_log(f"Process {os.getpid()} Pipeline namespace initialized") -async def get_update_flags(namespace: str): +async def get_update_flag(namespace: str): """ - Create a updated flags of a specific namespace. - Caller must store the dict object locally and use it to determine when to update the storage. + Create a namespace's update flag for a workers. + Returen the update flag to caller for referencing or reset. """ global _update_flags if _update_flags is None: @@ -178,8 +176,8 @@ async def get_update_flags(namespace: str): _update_flags[namespace].append(new_update_flag) return new_update_flag -async def set_update_flag(namespace: str): - """Set all update flag of namespace to indicate storage needs updating""" +async def set_all_update_flags(namespace: str): + """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") @@ -212,7 +210,7 @@ def try_initialize_namespace(namespace: str) -> bool: ) return True direct_log( - f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]" + f"Process {os.getpid()} storage namespace already initialized: [{namespace}]" ) return False diff --git a/lightrag/operate.py b/lightrag/operate.py index 5db5b5c6..e90854a0 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -338,7 +338,7 @@ async def extract_entities( ) -> None: from lightrag.kg.shared_storage import get_namespace_data - pipeline_status = get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data("pipeline_status") use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[ From d3de57c1e4b71a204ed428ba703d39f63c1c4c3f Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 10:37:05 +0800 Subject: [PATCH 54/77] Add multi-process support for vector database and graph storage with lock flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Implement storage lock mechanism โ€ข Add update flag handling โ€ข Add cross-process reload detection --- lightrag/kg/nano_vector_db_impl.py | 101 ++++++++++++++++++----- lightrag/kg/networkx_impl.py | 123 ++++++++++++++++++++++------- 2 files changed, 175 insertions(+), 49 deletions(-) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 86381379..e0047a21 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -16,7 +16,12 @@ if not pm.is_installed("nano-vectordb"): pm.install("nano-vectordb") from nano_vectordb import NanoVectorDB -from .shared_storage import get_storage_lock +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @@ -24,8 +29,9 @@ from .shared_storage import get_storage_lock class NanoVectorDBStorage(BaseVectorStorage): def __post_init__(self): # Initialize basic attributes - self._storage_lock = get_storage_lock() self._client = None + self._storage_lock = None + self.storage_updated = None # Use global config value if specified, otherwise use default kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -41,17 +47,38 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self._max_batch_size = 
self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + async def initialize(self): """Initialize storage data""" - async with self._storage_lock: - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name, - ) + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() - def _get_client(self): - """Check if the shtorage should be reloaded""" - return self._client + async def _get_client(self): + """Check if the storage should be reloaded""" + # Acquire lock to prevent concurrent read and write + async with self._storage_lock: + # Check if data needs to be reloaded + if (is_multiprocess and self.storage_updated.value) or \ + (not is_multiprocess and self.storage_updated): + logger.info(f"Reloading storage for {self.namespace} due to update by another process") + # Reload data + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + # Reset update flag + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + + return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} to {self.namespace}") @@ -81,7 +108,8 @@ class NanoVectorDBStorage(BaseVectorStorage): if len(embeddings) == len(list_data): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - results = self._get_client().upsert(datas=list_data) + client = await self._get_client() + results = client.upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. 
@@ -94,7 +122,8 @@ class NanoVectorDBStorage(BaseVectorStorage): embedding = await self.embedding_func([query]) embedding = embedding[0] - results = self._get_client().query( + client = await self._get_client() + results = client.query( query=embedding, top_k=top_k, better_than_threshold=self.cosine_better_than_threshold, @@ -111,8 +140,9 @@ class NanoVectorDBStorage(BaseVectorStorage): return results @property - def client_storage(self): - return getattr(self._get_client(), "_NanoVectorDB__storage") + async def client_storage(self): + client = await self._get_client() + return getattr(client, "_NanoVectorDB__storage") async def delete(self, ids: list[str]): """Delete vectors with specified IDs @@ -121,7 +151,8 @@ class NanoVectorDBStorage(BaseVectorStorage): ids: List of vector IDs to be deleted """ try: - self._get_client().delete(ids) + client = await self._get_client() + client.delete(ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) @@ -136,8 +167,9 @@ class NanoVectorDBStorage(BaseVectorStorage): ) # Check if the entity exists - if self._get_client().get([entity_id]): - self._get_client().delete([entity_id]) + client = await self._get_client() + if client.get([entity_id]): + client.delete([entity_id]) logger.debug(f"Successfully deleted entity {entity_name}") else: logger.debug(f"Entity {entity_name} not found in storage") @@ -146,7 +178,8 @@ class NanoVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: try: - storage = getattr(self._get_client(), "_NanoVectorDB__storage") + client = await self._get_client() + storage = getattr(client, "_NanoVectorDB__storage") relations = [ dp for dp in storage["data"] @@ -156,7 +189,8 @@ class NanoVectorDBStorage(BaseVectorStorage): ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: - self._get_client().delete(ids_to_delete) + client = await self._get_client() + client.delete(ids_to_delete) logger.debug( f"Deleted {len(ids_to_delete)} relations for {entity_name}" ) @@ -166,5 +200,32 @@ class NanoVectorDBStorage(BaseVectorStorage): logger.error(f"Error deleting relations for {entity_name}: {e}") async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...") + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + # Reset update flag + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + client = await self._get_client() async with self._storage_lock: - self._get_client().save() + try: + # Save data to disk + client.save() + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-notification + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + return True # Return success + except Exception as e: + logger.error(f"Error saving data for {self.namespace}: {e}") + return False # Return error diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index ccf85855..37db8469 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -17,7 +17,12 @@ if not pm.is_installed("graspologic"): import networkx as nx from 
graspologic import embed -from threading import Lock as ThreadLock +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @@ -73,10 +78,12 @@ class NetworkXStorage(BaseGraphStorage): self._graphml_xml_file = os.path.join( self.global_config["working_dir"], f"graph_{self.namespace}.graphml" ) - self._storage_lock = ThreadLock() + self._storage_lock = None + self.storage_updated = None + self._graph = None - with self._storage_lock: - preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) + # Load initial graph + preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) if preloaded_graph is not None: logger.info( f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" @@ -84,54 +91,83 @@ class NetworkXStorage(BaseGraphStorage): else: logger.info("Created new empty graph") self._graph = preloaded_graph or nx.Graph() + self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } - def _get_graph(self): - """Check if the shtorage should be reloaded""" - return self._graph + async def initialize(self): + """Initialize storage data""" + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() + + async def _get_graph(self): + """Check if the storage should be reloaded""" + # Acquire lock to prevent concurrent read and write + async with self._storage_lock: + # Check if data needs to be reloaded + if (is_multiprocess and self.storage_updated.value) or \ + (not is_multiprocess and self.storage_updated): + logger.info(f"Reloading graph for {self.namespace} due to update by another process") + # Reload data + self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + # Reset update flag + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + + return self._graph - async def index_done_callback(self) -> None: - with self._storage_lock: - NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: - return self._get_graph().has_node(node_id) + graph = await self._get_graph() + return graph.has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - return self._get_graph().has_edge(source_node_id, target_node_id) + graph = await self._get_graph() + return graph.has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: - return self._get_graph().nodes.get(node_id) + graph = await self._get_graph() + return graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: - return self._get_graph().degree(node_id) + graph = await self._get_graph() + return graph.degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: - return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id) + graph = await self._get_graph() + return graph.degree(src_id) + graph.degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - return self._get_graph().edges.get((source_node_id, target_node_id)) + graph = await self._get_graph() + return graph.edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - if 
self._get_graph().has_node(source_node_id): - return list(self._get_graph().edges(source_node_id)) + graph = await self._get_graph() + if graph.has_node(source_node_id): + return list(graph.edges(source_node_id)) return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - self._get_graph().add_node(node_id, **node_data) + graph = await self._get_graph() + graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) + graph = await self._get_graph() + graph.add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: - if self._get_graph().has_node(node_id): - self._get_graph().remove_node(node_id) + graph = await self._get_graph() + if graph.has_node(node_id): + graph.remove_node(node_id) logger.debug(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") @@ -145,7 +181,7 @@ class NetworkXStorage(BaseGraphStorage): # TODO: NOT USED async def _node2vec_embed(self): - graph = self._get_graph() + graph = await self._get_graph() embeddings, nodes = embed.node2vec_embed( graph, **self.global_config["node2vec_params"], @@ -153,24 +189,24 @@ class NetworkXStorage(BaseGraphStorage): nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids - def remove_nodes(self, nodes: list[str]): + async def remove_nodes(self, nodes: list[str]): """Delete multiple nodes Args: nodes: List of node IDs to be deleted """ - graph = self._get_graph() + graph = await self._get_graph() for node in nodes: if graph.has_node(node): graph.remove_node(node) - def remove_edges(self, edges: list[tuple[str, str]]): + async def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ - graph = self._get_graph() + graph = await self._get_graph() for source, target in edges: if graph.has_edge(source, target): graph.remove_edge(source, target) @@ -181,8 +217,9 @@ class NetworkXStorage(BaseGraphStorage): Returns: [label1, label2, ...] 
# Alphabetically sorted label list """ + graph = await self._get_graph() labels = set() - for node in self._get_graph().nodes(): + for node in graph.nodes(): labels.add(str(node)) # Add node id as a label # Return sorted list @@ -205,7 +242,7 @@ class NetworkXStorage(BaseGraphStorage): seen_nodes = set() seen_edges = set() - graph = self._get_graph() + graph = await self._get_graph() # Handle special case for "*" label if node_label == "*": @@ -291,3 +328,31 @@ class NetworkXStorage(BaseGraphStorage): f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result + + async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...") + self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + # Reset update flag + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + graph = await self._get_graph() + async with self._storage_lock: + try: + # Save data to disk + NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file) + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-notification + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + return True # Return success + except Exception as e: + logger.error(f"Error saving graph for {self.namespace}: {e}") + return False # Return error From c07a5039b7e5e37101f4cfa4017aaa516a61f91a Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 10:48:55 +0800 Subject: [PATCH 55/77] Refactor shared storage locks to separate pipeline, storage and internal locks for deadlock preventing --- lightrag/api/lightrag_server.py | 8 ++-- lightrag/kg/shared_storage.py | 79 +++++++++++++++++++++------------ lightrag/lightrag.py | 10 ++--- 3 files changed, 59 insertions(+), 38 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index ca0958ee..f5f9f8ea 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -138,17 +138,17 @@ def create_app(args): # Import necessary functions from shared_storage from lightrag.kg.shared_storage import ( get_namespace_data, - get_storage_lock, - initialize_pipeline_namespace, + get_pipeline_status_lock, + initialize_pipeline_status, ) - await initialize_pipeline_namespace() + await initialize_pipeline_status() # Auto scan documents if enabled if args.auto_scan_at_startup: # Check if a task is already running (with lock protection) pipeline_status = await get_namespace_data("pipeline_status") should_start_task = False - async with get_storage_lock(): + async with get_pipeline_status_lock(): if not pipeline_status.get("busy", False): should_start_task = True # Only start the task if no other task is running diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 940d0e7b..237ed302 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -16,6 +16,22 @@ def direct_log(message, level="INFO"): T = TypeVar('T') +LockType = Union[ProcessLock, asyncio.Lock] + +is_multiprocess = None +_workers = None +_manager = None +_initialized = None + +# shared data for storage across processes +_shared_dicts: 
Optional[Dict[str, Any]] = None +_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized +_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated + +# locks for mutex access +_storage_lock: Optional[LockType] = None +_internal_lock: Optional[LockType] = None +_pipeline_status_lock: Optional[LockType] = None class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" @@ -39,30 +55,37 @@ class UnifiedLock(Generic[T]): def __enter__(self) -> 'UnifiedLock[T]': """For backward compatibility""" if self._is_async: - raise RuntimeError("Use 'async with' for asyncio.Lock") + raise RuntimeError("Use 'async with' for shared_storage lock") self._lock.acquire() return self def __exit__(self, exc_type, exc_val, exc_tb): """For backward compatibility""" if self._is_async: - raise RuntimeError("Use 'async with' for asyncio.Lock") + raise RuntimeError("Use 'async with' for shared_storage lock") self._lock.release() -LockType = Union[ProcessLock, asyncio.Lock] +def get_internal_lock() -> UnifiedLock: + """return unified storage lock for data consistency""" + return UnifiedLock( + lock=_internal_lock, + is_async=not is_multiprocess + ) -is_multiprocess = None -_workers = None -_manager = None -_initialized = None -_global_lock: Optional[LockType] = None - -# shared data for storage across processes -_shared_dicts: Optional[Dict[str, Any]] = None -_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized -_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated +def get_storage_lock() -> UnifiedLock: + """return unified storage lock for data consistency""" + return UnifiedLock( + lock=_storage_lock, + is_async=not is_multiprocess + ) +def get_pipeline_status_lock() -> UnifiedLock: + """return unified storage lock for data consistency""" + return UnifiedLock( + lock=_pipeline_status_lock, + is_async=not is_multiprocess + ) def initialize_share_data(workers: int = 1): """ @@ -87,7 +110,9 @@ def initialize_share_data(workers: int = 1): _workers, \ is_multiprocess, \ is_multiprocess, \ - _global_lock, \ + _storage_lock, \ + _internal_lock, \ + _pipeline_status_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ @@ -105,7 +130,9 @@ def initialize_share_data(workers: int = 1): if workers > 1: is_multiprocess = True - _global_lock = _manager.Lock() + _internal_lock = _manager.Lock() + _storage_lock = _manager.Lock() + _pipeline_status_lock = _manager.Lock() _shared_dicts = _manager.dict() _init_flags = _manager.dict() _update_flags = _manager.dict() @@ -114,7 +141,9 @@ def initialize_share_data(workers: int = 1): ) else: is_multiprocess = False - _global_lock = asyncio.Lock() + _internal_lock = asyncio.Lock() + _storage_lock = asyncio.Lock() + _pipeline_status_lock = asyncio.Lock() _shared_dicts = {} _init_flags = {} _update_flags = {} @@ -124,13 +153,13 @@ def initialize_share_data(workers: int = 1): _initialized = True -async def initialize_pipeline_namespace(): +async def initialize_pipeline_status(): """ Initialize pipeline namespace with default values. 
""" pipeline_namespace = await get_namespace_data("pipeline_status") - async with get_storage_lock(): + async with get_internal_lock(): # Check if already initialized by checking for required fields if "busy" in pipeline_namespace: return @@ -160,7 +189,7 @@ async def get_update_flag(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - async with get_storage_lock(): + async with get_internal_lock(): if namespace not in _update_flags: if is_multiprocess and _manager is not None: _update_flags[namespace] = _manager.list() @@ -182,7 +211,7 @@ async def set_all_update_flags(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - async with get_storage_lock(): + async with get_internal_lock(): if namespace not in _update_flags: raise ValueError(f"Namespace {namespace} not found in update flags") # Update flags for both modes @@ -215,14 +244,6 @@ def try_initialize_namespace(namespace: str) -> bool: return False -def get_storage_lock() -> UnifiedLock: - """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_global_lock, - is_async=not is_multiprocess - ) - - async def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" if _shared_dicts is None: @@ -232,7 +253,7 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]: ) raise ValueError("Shared dictionaries not initialized") - async with get_storage_lock(): + async with get_internal_lock(): if namespace not in _shared_dicts: if is_multiprocess and _manager is not None: _shared_dicts[namespace] = _manager.dict() diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 4b85a3b7..e7420a35 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -669,15 +669,15 @@ class LightRAG: 3. Process each chunk for entity and relation extraction 4. 
Update the document status """ - from lightrag.kg.shared_storage import get_namespace_data, get_storage_lock + from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock # Get pipeline status shared data and lock pipeline_status = await get_namespace_data("pipeline_status") - storage_lock = get_storage_lock() + pipeline_status_lock = get_pipeline_status_lock() # Check if another process is already processing the queue process_documents = False - async with storage_lock: + async with pipeline_status_lock: # Ensure only one worker is processing documents if not pipeline_status.get("busy", False): # Cleaning history_messages without breaking it as a shared list object @@ -851,7 +851,7 @@ class LightRAG: # Check if there's a pending request to process more documents (with lock) has_pending_request = False - async with storage_lock: + async with pipeline_status_lock: has_pending_request = pipeline_status.get("request_pending", False) if has_pending_request: # Clear the request flag before checking for more documents @@ -869,7 +869,7 @@ class LightRAG: log_message = "Document processing pipeline completed" logger.info(log_message) # Always reset busy status when done or if an exception occurs (with lock) - async with storage_lock: + async with pipeline_status_lock: pipeline_status["busy"] = False pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) From d4f6dcfd54963183e10f43a7a311425cbbb4f5bd Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 12:41:30 +0800 Subject: [PATCH 56/77] Improve multi-process data synchronization and persistence in storage implementations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Remove _get_client() or _get_graph() from index_done_callback โ€ข Add return value for index_done_callback --- lightrag/kg/nano_vector_db_impl.py | 12 +++++++----- lightrag/kg/networkx_impl.py | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index e0047a21..c17189c6 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -66,7 +66,7 @@ class NanoVectorDBStorage(BaseVectorStorage): # Check if data needs to be reloaded if (is_multiprocess and self.storage_updated.value) or \ (not is_multiprocess and self.storage_updated): - logger.info(f"Reloading storage for {self.namespace} due to update by another process") + logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process") # Reload data self._client = NanoVectorDB( self.embedding_func.embedding_dim, @@ -199,7 +199,8 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") - async def index_done_callback(self) -> None: + async def index_done_callback(self) -> bool: + """Save data to disk""" # Check if storage was updated by another process if is_multiprocess and self.storage_updated.value: # Storage was updated by another process, reload data instead of saving @@ -213,14 +214,13 @@ class NanoVectorDBStorage(BaseVectorStorage): return False # Return error # Acquire lock and perform persistence - client = await self._get_client() async with self._storage_lock: try: # Save data to disk - client.save() + self._get_client.save() # Notify other processes that data has been updated await set_all_update_flags(self.namespace) - # Reset own update flag to avoid 
self-notification + # Reset own update flag to avoid self-reloading if is_multiprocess: self.storage_updated.value = False else: @@ -229,3 +229,5 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error saving data for {self.namespace}: {e}") return False # Return error + + return True # Return success diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 37db8469..2e61e6b3 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -110,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage): # Check if data needs to be reloaded if (is_multiprocess and self.storage_updated.value) or \ (not is_multiprocess and self.storage_updated): - logger.info(f"Reloading graph for {self.namespace} due to update by another process") + logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process") # Reload data self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() # Reset update flag @@ -329,7 +329,8 @@ class NetworkXStorage(BaseGraphStorage): ) return result - async def index_done_callback(self) -> None: + async def index_done_callback(self) -> bool: + """Save data to disk""" # Check if storage was updated by another process if is_multiprocess and self.storage_updated.value: # Storage was updated by another process, reload data instead of saving @@ -340,14 +341,13 @@ class NetworkXStorage(BaseGraphStorage): return False # Return error # Acquire lock and perform persistence - graph = await self._get_graph() async with self._storage_lock: try: # Save data to disk - NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file) + NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) # Notify other processes that data has been updated await set_all_update_flags(self.namespace) - # Reset own update flag to avoid self-notification + # Reset own update flag to avoid self-reloading if is_multiprocess: self.storage_updated.value = False else: @@ -356,3 +356,5 @@ class NetworkXStorage(BaseGraphStorage): except Exception as e: logger.error(f"Error saving graph for {self.namespace}: {e}") return False # Return error + + return True From 35bcfca28febc1e03adbe624e59cd29d9945e4a0 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 12:42:30 +0800 Subject: [PATCH 57/77] feat: add multi-process support for FAISS vector storage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add storage update flag and locks โ€ข Support cross-process index reload โ€ข Add async initialize method --- lightrag/kg/faiss_impl.py | 78 ++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index d0ef6ed0..f244c288 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -15,7 +15,12 @@ if not pm.is_installed("faiss"): pm.install("faiss") import faiss # type: ignore -from threading import Lock as ThreadLock +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @@ -45,29 +50,43 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim - self._storage_lock = ThreadLock() - + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). 
# If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP. self._index = faiss.IndexFlatIP(self._dim) - # Keep a local store for metadata, IDs, etc. # Maps โ†’ metadata (including your original ID). self._id_to_meta = {} - # Attempt to load an existing index + metadata from disk - with self._storage_lock: - self._load_faiss_index() + self._load_faiss_index() - def _get_index(self): + async def initialize(self): + """Initialize storage data""" + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() + + async def _get_index(self): """Check if the shtorage should be reloaded""" + # Acquire lock to prevent concurrent read and write + with self._storage_lock: + # Check if storage was updated by another process + if (is_multiprocess and self.storage_updated.value) or \ + (not is_multiprocess and self.storage_updated): + logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process") + # Reload data + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False return self._index - async def index_done_callback(self) -> None: - with self._storage_lock: - self._save_faiss_index() - + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._remove_faiss_ids(existing_ids_to_remove) # Step 2: Add new vectors - index = self._get_index() + index = await self._get_index() start_idx = index.ntotal index.add(embeddings) @@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - distances, indices = self._get_index().search(embedding, top_k) + index = await self._get_index() + distances, indices = index().search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.warning("Starting with an empty Faiss index.") self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} + +async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...") + with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Set all update flags to False + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") + return False # Return error + + return True # Return success From 48d98005738214925921e6c77d4844aa76b8e4f3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 12:58:52 +0800 Subject: [PATCH 58/77] Enhance gunicorn config 
handling with env vars and command line arg priority MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add env var support for timeout/keepalive โ€ข Prioritize CLI args over env vars โ€ข Standardize default timeout to 150s --- gunicorn_config.py | 4 ++-- run_with_gunicorn.py | 34 ++++++++++++++++++++-------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index a13054e3..3f5d5db2 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -26,8 +26,8 @@ preload_app = True worker_class = "uvicorn.workers.UvicornWorker" # Other Gunicorn configurations -timeout = int(os.getenv("TIMEOUT", 120)) -keepalive = 5 +timeout = int(os.getenv("TIMEOUT", 150)) # Default 150s to match run_with_gunicorn.py +keepalive = int(os.getenv("KEEPALIVE", 5)) # Default 5s # Logging configuration errorlog = os.getenv("ERROR_LOG", log_file_path) # Default write to lightrag.log diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index a7692085..6aa2b0f3 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -96,21 +96,27 @@ def main(): # Import and configure the gunicorn_config module import gunicorn_config - # Set configuration variables in gunicorn_config - gunicorn_config.workers = int(os.getenv("WORKERS", args.workers)) - gunicorn_config.bind = ( - f"{os.getenv('HOST', args.host)}:{os.getenv('PORT', args.port)}" - ) - gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level - else os.getenv("LOG_LEVEL", "info") - ) + # Set configuration variables in gunicorn_config, prioritizing command line arguments + gunicorn_config.workers = args.workers if args.workers else int(os.getenv("WORKERS", 1)) + + # Bind configuration prioritizes command line arguments + host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") + port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) + gunicorn_config.bind = f"{host}:{port}" + + # Log level configuration prioritizes command line arguments + gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info") - # Set SSL configuration if enabled - if args.ssl: - gunicorn_config.certfile = args.ssl_certfile - gunicorn_config.keyfile = args.ssl_keyfile + # Timeout configuration prioritizes command line arguments + gunicorn_config.timeout = args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + + # Keepalive configuration + gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) + + # SSL configuration prioritizes command line arguments + if args.ssl or os.getenv("SSL", "").lower() in ("true", "1", "yes", "t", "on"): + gunicorn_config.certfile = args.ssl_certfile if args.ssl_certfile else os.getenv("SSL_CERTFILE") + gunicorn_config.keyfile = args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") # Set configuration options from the module for key in dir(gunicorn_config): From 41eff2ca2f82d9bbe5a5e54ef7080068d65a6fdb Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 13:35:00 +0800 Subject: [PATCH 59/77] Fix data persistence issue in NanoVectorDBStorage --- lightrag/kg/nano_vector_db_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index c17189c6..e0ecacdf 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -217,7 +217,7 @@ class NanoVectorDBStorage(BaseVectorStorage): async with self._storage_lock: try: # Save 
data to disk - self._get_client.save() + self._client.save() # Notify other processes that data has been updated await set_all_update_flags(self.namespace) # Reset own update flag to avoid self-reloading From 40e9e26edb0693145aae156d8ecd3e034378cbe7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 14:58:26 +0800 Subject: [PATCH 60/77] feat: add update flags status to API health endpoint --- lightrag/api/lightrag_server.py | 17 ++++++++++------- lightrag/kg/shared_storage.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f5f9f8ea..76901b90 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -39,6 +39,12 @@ from .routers.graph_routes import create_graph_routes from .routers.ollama_api import OllamaAPI from lightrag.utils import logger, set_verbose_debug +from lightrag.kg.shared_storage import ( + get_namespace_data, + get_pipeline_status_lock, + initialize_pipeline_status, + get_all_update_flags_status, +) # Load environment variables load_dotenv(override=True) @@ -134,13 +140,6 @@ def create_app(args): try: # Initialize database connections await rag.initialize_storages() - - # Import necessary functions from shared_storage - from lightrag.kg.shared_storage import ( - get_namespace_data, - get_pipeline_status_lock, - initialize_pipeline_status, - ) await initialize_pipeline_status() # Auto scan documents if enabled @@ -376,6 +375,9 @@ def create_app(args): @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" + # Get update flags status for all namespaces + update_status = await get_all_update_flags_status() + return { "status": "healthy", "working_directory": str(args.working_dir), @@ -395,6 +397,7 @@ def create_app(args): "graph_storage": args.graph_storage, "vector_storage": args.vector_storage, }, + "update_status": update_status, } # Webui mount webui/index.html diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 237ed302..27d23f2e 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -222,6 +222,30 @@ async def set_all_update_flags(namespace: str): _update_flags[namespace][i] = True +async def get_all_update_flags_status() -> Dict[str, list]: + """ + Get update flags status for all namespaces. + + Returns: + Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses + """ + if _update_flags is None: + return {} + + result = {} + async with get_internal_lock(): + for namespace, flags in _update_flags.items(): + worker_statuses = [] + for flag in flags: + if is_multiprocess: + worker_statuses.append(flag.value) + else: + worker_statuses.append(flag) + result[namespace] = worker_statuses + + return result + + def try_initialize_namespace(namespace: str) -> bool: """ Try to initialize a namespace. Returns True if the current process gets initialization permission. From ab704aae47bbfc0c6ecef86d5a8d6c48e50b57a6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 15:13:21 +0800 Subject: [PATCH 61/77] Config log setting for very woker properly for Gunicorn mode. 
--- gunicorn_config.py | 56 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/gunicorn_config.py b/gunicorn_config.py index 3f5d5db2..7f9b4d58 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -134,18 +134,54 @@ def post_fork(server, worker): Executed after a worker has been forked. This is a good place to set up worker-specific configurations. """ - # Set lightrag logger level in worker processes using gunicorn's loglevel - from lightrag.utils import logger + # Configure formatters + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + simple_formatter = logging.Formatter("%(levelname)s: %(message)s") - logger.setLevel(loglevel.upper()) + def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False): + """Set up a logger with console and file handlers""" + logger_instance = logging.getLogger(logger_name) + logger_instance.setLevel(level) + logger_instance.handlers = [] # Clear existing handlers + logger_instance.propagate = False - # Disable uvicorn.error logger in worker processes + # Add console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(simple_formatter) + console_handler.setLevel(level) + logger_instance.addHandler(console_handler) + + # Add file handler + file_handler = logging.handlers.RotatingFileHandler( + filename=log_file_path, + maxBytes=log_max_bytes, + backupCount=log_backup_count, + encoding="utf-8", + ) + file_handler.setFormatter(detailed_formatter) + file_handler.setLevel(level) + logger_instance.addHandler(file_handler) + + # Add path filter if requested + if add_filter: + path_filter = LightragPathFilter() + logger_instance.addFilter(path_filter) + + # Set up main loggers + log_level = loglevel.upper() if loglevel else "INFO" + setup_logger("uvicorn", log_level) + setup_logger("uvicorn.access", log_level, add_filter=True) + setup_logger("lightrag", log_level, add_filter=True) + + # Set up lightrag submodule loggers + for name in logging.root.manager.loggerDict: + if name.startswith("lightrag."): + setup_logger(name, log_level, add_filter=True) + + # Disable uvicorn.error logger uvicorn_error_logger = logging.getLogger("uvicorn.error") - uvicorn_error_logger.setLevel(logging.CRITICAL) uvicorn_error_logger.handlers = [] + uvicorn_error_logger.setLevel(logging.CRITICAL) uvicorn_error_logger.propagate = False - - # Add log filter to uvicorn.access handler in worker processes - uvicorn_access_logger = logging.getLogger("uvicorn.access") - path_filter = LightragPathFilter() - uvicorn_access_logger.addFilter(path_filter) From 3511b9805c002b1b7c57fd11b0a846874a1a42ff Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 15:34:01 +0800 Subject: [PATCH 62/77] Add auto-installation of gunicorn if not present in environment --- run_with_gunicorn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 6aa2b0f3..decd91de 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -45,6 +45,12 @@ def main(): print(f"Workers setting: {args.workers}") print("=" * 80 + "\n") + # Check and install gunicorn if not present + import pipmaster as pm + if not pm.is_installed("gunicorn"): + print("Installing gunicorn...") + pm.install("gunicorn") + # Import Gunicorn's StandaloneApplication from gunicorn.app.base import BaseApplication From d18eb52ccc15a191c4733b52ce64bd76fe041873 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 15:38:39 +0800 
Subject: [PATCH 63/77] Add type ignore comments for asyncpg imports to suppress mypy errors --- lightrag/kg/postgres_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index c91d23f0..10883a88 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -38,8 +38,8 @@ import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") -import asyncpg -from asyncpg import Pool +import asyncpg # type: ignore +from asyncpg import Pool # type: ignore class PostgreSQLDB: From e3a40c2fdbc041e176c642ece9e576385d2b0502 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 16:23:34 +0800 Subject: [PATCH 64/77] Fix linting --- lightrag/kg/faiss_impl.py | 65 ++++++++++++++------------- lightrag/kg/nano_vector_db_impl.py | 17 ++++--- lightrag/kg/networkx_impl.py | 30 ++++++++----- lightrag/kg/postgres_impl.py | 4 +- lightrag/kg/shared_storage.py | 71 +++++++++++++++--------------- lightrag/lightrag.py | 5 ++- run_with_gunicorn.py | 41 ++++++++++++----- 7 files changed, 138 insertions(+), 95 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index f244c288..bb4d47ec 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -50,7 +50,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim - + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP. @@ -73,9 +73,12 @@ class FaissVectorDBStorage(BaseVectorStorage): # Acquire lock to prevent concurrent read and write with self._storage_lock: # Check if storage was updated by another process - if (is_multiprocess and self.storage_updated.value) or \ - (not is_multiprocess and self.storage_updated): - logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process") + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process" + ) # Reload data self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} @@ -86,7 +89,6 @@ class FaissVectorDBStorage(BaseVectorStorage): self.storage_updated = False return self._index - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. 
@@ -337,32 +339,35 @@ class FaissVectorDBStorage(BaseVectorStorage): self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} + async def index_done_callback(self) -> None: - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...") - with self._storage_lock: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} - self._load_faiss_index() + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for FAISS {self.namespace} was updated by another process, reloading..." + ) + with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Set all update flags to False + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") return False # Return error - # Acquire lock and perform persistence - async with self._storage_lock: - try: - # Save data to disk - self._save_faiss_index() - # Set all update flags to False - await set_all_update_flags(self.namespace) - # Reset own update flag to avoid self-reloading - if is_multiprocess: - self.storage_updated.value = False - else: - self.storage_updated = False - except Exception as e: - logger.error(f"Error saving FAISS index for {self.namespace}: {e}") - return False # Return error - - return True # Return success + return True # Return success diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index e0ecacdf..07c800de 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -64,9 +64,12 @@ class NanoVectorDBStorage(BaseVectorStorage): # Acquire lock to prevent concurrent read and write async with self._storage_lock: # Check if data needs to be reloaded - if (is_multiprocess and self.storage_updated.value) or \ - (not is_multiprocess and self.storage_updated): - logger.info(f"Process {os.getpid()} reloading {self.namespace} due to update by another process") + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} reloading {self.namespace} due to update by another process" + ) # Reload data self._client = NanoVectorDB( self.embedding_func.embedding_dim, @@ -77,7 +80,7 @@ class NanoVectorDBStorage(BaseVectorStorage): self.storage_updated.value = False else: self.storage_updated = False - + return self._client async def upsert(self, data: dict[str, dict[str, Any]]) -> None: @@ -204,7 +207,9 @@ class NanoVectorDBStorage(BaseVectorStorage): # Check if storage was updated by another process if is_multiprocess and self.storage_updated.value: # Storage was updated by another process, reload data instead of saving - logger.warning(f"Storage for {self.namespace} was updated by another process, reloading...") + logger.warning( + 
f"Storage for {self.namespace} was updated by another process, reloading..." + ) self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name, @@ -212,7 +217,7 @@ class NanoVectorDBStorage(BaseVectorStorage): # Reset update flag self.storage_updated.value = False return False # Return error - + # Acquire lock and perform persistence async with self._storage_lock: try: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 2e61e6b3..f11e9c0e 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -91,7 +91,7 @@ class NetworkXStorage(BaseGraphStorage): else: logger.info("Created new empty graph") self._graph = preloaded_graph or nx.Graph() - + self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } @@ -108,19 +108,23 @@ class NetworkXStorage(BaseGraphStorage): # Acquire lock to prevent concurrent read and write async with self._storage_lock: # Check if data needs to be reloaded - if (is_multiprocess and self.storage_updated.value) or \ - (not is_multiprocess and self.storage_updated): - logger.info(f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process") + if (is_multiprocess and self.storage_updated.value) or ( + not is_multiprocess and self.storage_updated + ): + logger.info( + f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process" + ) # Reload data - self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) # Reset update flag if is_multiprocess: self.storage_updated.value = False else: self.storage_updated = False - - return self._graph + return self._graph async def has_node(self, node_id: str) -> bool: graph = await self._get_graph() @@ -334,12 +338,16 @@ class NetworkXStorage(BaseGraphStorage): # Check if storage was updated by another process if is_multiprocess and self.storage_updated.value: # Storage was updated by another process, reload data instead of saving - logger.warning(f"Graph for {self.namespace} was updated by another process, reloading...") - self._graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + logger.warning( + f"Graph for {self.namespace} was updated by another process, reloading..." 
+ ) + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) # Reset update flag self.storage_updated.value = False return False # Return error - + # Acquire lock and perform persistence async with self._storage_lock: try: @@ -356,5 +364,5 @@ class NetworkXStorage(BaseGraphStorage): except Exception as e: logger.error(f"Error saving graph for {self.namespace}: {e}") return False # Return error - + return True diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 10883a88..51044be5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -38,8 +38,8 @@ import pipmaster as pm if not pm.is_installed("asyncpg"): pm.install("asyncpg") -import asyncpg # type: ignore -from asyncpg import Pool # type: ignore +import asyncpg # type: ignore +from asyncpg import Pool # type: ignore class PostgreSQLDB: diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 27d23f2e..acebafa7 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -15,7 +15,7 @@ def direct_log(message, level="INFO"): print(f"{level}: {message}", file=sys.stderr, flush=True) -T = TypeVar('T') +T = TypeVar("T") LockType = Union[ProcessLock, asyncio.Lock] is_multiprocess = None @@ -26,20 +26,22 @@ _initialized = None # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized -_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated +_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated # locks for mutex access _storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None _pipeline_status_lock: Optional[LockType] = None + class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" + def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): self._lock = lock self._is_async = is_async - async def __aenter__(self) -> 'UnifiedLock[T]': + async def __aenter__(self) -> "UnifiedLock[T]": if self._is_async: await self._lock.acquire() else: @@ -52,7 +54,7 @@ class UnifiedLock(Generic[T]): else: self._lock.release() - def __enter__(self) -> 'UnifiedLock[T]': + def __enter__(self) -> "UnifiedLock[T]": """For backward compatibility""" if self._is_async: raise RuntimeError("Use 'async with' for shared_storage lock") @@ -68,24 +70,18 @@ class UnifiedLock(Generic[T]): def get_internal_lock() -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_internal_lock, - is_async=not is_multiprocess - ) + return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess) + def get_storage_lock() -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_storage_lock, - is_async=not is_multiprocess - ) + return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess) + def get_pipeline_status_lock() -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock( - lock=_pipeline_status_lock, - is_async=not is_multiprocess - ) + return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess) + def initialize_share_data(workers: int = 1): """ @@ -166,17 +162,19 @@ async def initialize_pipeline_status(): # Create a shared list object for history_messages history_messages = _manager.list() if is_multiprocess else [] - pipeline_namespace.update({ - "busy": False, # Control 
concurrent processes - "job_name": "Default Job", # Current job name (indexing files/indexing texts) - "job_start": None, # Job start time - "docs": 0, # Total number of documents to be indexed - "batchs": 0, # Number of batches for processing documents - "cur_batch": 0, # Current processing batch - "request_pending": False, # Flag for pending request for processing - "latest_message": "", # Latest message from pipeline processing - "history_messages": history_messages, # ไฝฟ็”จๅ…ฑไบซๅˆ—่กจๅฏน่ฑก - }) + pipeline_namespace.update( + { + "busy": False, # Control concurrent processes + "job_name": "Default Job", # Current job name (indexing files/indexing texts) + "job_start": None, # Job start time + "docs": 0, # Total number of documents to be indexed + "batchs": 0, # Number of batches for processing documents + "cur_batch": 0, # Current processing batch + "request_pending": False, # Flag for pending request for processing + "latest_message": "", # Latest message from pipeline processing + "history_messages": history_messages, # ไฝฟ็”จๅ…ฑไบซๅˆ—่กจๅฏน่ฑก + } + ) direct_log(f"Process {os.getpid()} Pipeline namespace initialized") @@ -195,22 +193,25 @@ async def get_update_flag(namespace: str): _update_flags[namespace] = _manager.list() else: _update_flags[namespace] = [] - direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") - + direct_log( + f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]" + ) + if is_multiprocess and _manager is not None: - new_update_flag = _manager.Value('b', False) + new_update_flag = _manager.Value("b", False) else: new_update_flag = False - + _update_flags[namespace].append(new_update_flag) return new_update_flag + async def set_all_update_flags(namespace: str): """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") - + async with get_internal_lock(): if namespace not in _update_flags: raise ValueError(f"Namespace {namespace} not found in update flags") @@ -225,13 +226,13 @@ async def set_all_update_flags(namespace: str): async def get_all_update_flags_status() -> Dict[str, list]: """ Get update flags status for all namespaces. - + Returns: Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses """ if _update_flags is None: return {} - + result = {} async with get_internal_lock(): for namespace, flags in _update_flags.items(): @@ -242,7 +243,7 @@ async def get_all_update_flags_status() -> Dict[str, list]: else: worker_statuses.append(flag) result[namespace] = worker_statuses - + return result diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6008b39c..44b77ae7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -696,7 +696,10 @@ class LightRAG: 3. Process each chunk for entity and relation extraction 4. 
Update the document status """ - from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_pipeline_status_lock, + ) # Get pipeline status shared data and lock pipeline_status = await get_namespace_data("pipeline_status") diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index decd91de..644e6e87 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -47,10 +47,11 @@ def main(): # Check and install gunicorn if not present import pipmaster as pm + if not pm.is_installed("gunicorn"): print("Installing gunicorn...") pm.install("gunicorn") - + # Import Gunicorn's StandaloneApplication from gunicorn.app.base import BaseApplication @@ -103,26 +104,46 @@ def main(): import gunicorn_config # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = args.workers if args.workers else int(os.getenv("WORKERS", 1)) - + gunicorn_config.workers = ( + args.workers if args.workers else int(os.getenv("WORKERS", 1)) + ) + # Bind configuration prioritizes command line arguments host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) gunicorn_config.bind = f"{host}:{port}" - + # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = args.log_level.lower() if args.log_level else os.getenv("LOG_LEVEL", "info") + gunicorn_config.loglevel = ( + args.log_level.lower() + if args.log_level + else os.getenv("LOG_LEVEL", "info") + ) # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) - + gunicorn_config.timeout = ( + args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + ) + # Keepalive configuration gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ("true", "1", "yes", "t", "on"): - gunicorn_config.certfile = args.ssl_certfile if args.ssl_certfile else os.getenv("SSL_CERTFILE") - gunicorn_config.keyfile = args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + if args.ssl or os.getenv("SSL", "").lower() in ( + "true", + "1", + "yes", + "t", + "on", + ): + gunicorn_config.certfile = ( + args.ssl_certfile + if args.ssl_certfile + else os.getenv("SSL_CERTFILE") + ) + gunicorn_config.keyfile = ( + args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + ) # Set configuration options from the module for key in dir(gunicorn_config): From 9aef112d512c51c2d53ea1072bf6a6f35c44bcfc Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 22:27:12 +0800 Subject: [PATCH 65/77] Fix incorrect comment about update flag behavior in FAISS implementation --- lightrag/kg/faiss_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index bb4d47ec..e71f77a8 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -359,7 +359,7 @@ async def index_done_callback(self) -> None: try: # Save data to disk self._save_faiss_index() - # Set all update flags to False + # Notify other processes that data has been updated await set_all_update_flags(self.namespace) # Reset own update flag to avoid self-reloading if is_multiprocess: From e8d0d065f3bcb2b0ebce3c5d9c87cdcd701eab8e Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 
Mar 2025 23:35:09 +0800 Subject: [PATCH 66/77] fix: Improve async handling and FAISS storage reliability - Add async context manager support - Fix embedding data type conversion - Improve error handling in FAISS ops - Add multiprocess storage sync --- lightrag/api/README.md | 2 +- lightrag/kg/faiss_impl.py | 74 +++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 86f18271..35062cad 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -186,7 +186,7 @@ LightRAG supports binding to various LLM/Embedding backends: * openai & openai compatible * azure_openai -Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type. +Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select LLM backend type. ### Storage Types Supported diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index e71f77a8..940ba73d 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -71,7 +71,7 @@ class FaissVectorDBStorage(BaseVectorStorage): async def _get_index(self): """Check if the shtorage should be reloaded""" # Acquire lock to prevent concurrent read and write - with self._storage_lock: + async with self._storage_lock: # Check if storage was updated by another process if (is_multiprocess and self.storage_updated.value) or ( not is_multiprocess and self.storage_updated @@ -139,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage): ) return [] - # Normalize embeddings for cosine similarity (in-place) + # Convert to float32 and normalize embeddings for cosine similarity (in-place) + embeddings = embeddings.astype(np.float32) faiss.normalize_L2(embeddings) # Upsert logic: @@ -153,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage): existing_ids_to_remove.append(faiss_internal_id) if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + await self._remove_faiss_ids(existing_ids_to_remove) # Step 2: Add new vectors index = await self._get_index() @@ -185,7 +186,7 @@ class FaissVectorDBStorage(BaseVectorStorage): # Perform the similarity search index = await self._get_index() - distances, indices = index().search(embedding, top_k) + distances, indices = index.search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -229,7 +230,7 @@ class FaissVectorDBStorage(BaseVectorStorage): to_remove.append(fid) if to_remove: - self._remove_faiss_ids(to_remove) + await self._remove_faiss_ids(to_remove) logger.debug( f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" ) @@ -251,7 +252,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Found {len(relations)} relations for {entity_name}") if relations: - self._remove_faiss_ids(relations) + await self._remove_faiss_ids(relations) logger.debug(f"Deleted {len(relations)} relations for {entity_name}") # -------------------------------------------------------------------------------- @@ -267,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage): return fid return None - def _remove_faiss_ids(self, fid_list): + async def _remove_faiss_ids(self, fid_list): """ Remove a list of internal Faiss IDs from the index. 
Because IndexFlatIP doesn't support 'removals', @@ -283,7 +284,7 @@ class FaissVectorDBStorage(BaseVectorStorage): vectors_to_keep.append(vec_meta["__vector__"]) # stored as list new_id_to_meta[new_fid] = vec_meta - with self._storage_lock: + async with self._storage_lock: # Re-init index self._index = faiss.IndexFlatIP(self._dim) if vectors_to_keep: @@ -339,35 +340,34 @@ class FaissVectorDBStorage(BaseVectorStorage): self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} - -async def index_done_callback(self) -> None: - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning( - f"Storage for FAISS {self.namespace} was updated by another process, reloading..." - ) - with self._storage_lock: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} - self._load_faiss_index() - self.storage_updated.value = False - return False # Return error - - # Acquire lock and perform persistence - async with self._storage_lock: - try: - # Save data to disk - self._save_faiss_index() - # Notify other processes that data has been updated - await set_all_update_flags(self.namespace) - # Reset own update flag to avoid self-reloading - if is_multiprocess: + async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for FAISS {self.namespace} was updated by another process, reloading..." + ) + async with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() self.storage_updated.value = False - else: - self.storage_updated = False - except Exception as e: - logger.error(f"Error saving FAISS index for {self.namespace}: {e}") return False # Return error - return True # Return success + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") + return False # Return error + + return True # Return success From f76cf98dbd072e99f1ab275df0e488ee5f3898a0 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 00:04:59 +0800 Subject: [PATCH 67/77] Add automatic dependency checking and installation for server startup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Added check_and_install_dependencies() โ€ข Install missing dependencies automatically --- lightrag/api/lightrag_server.py | 19 +++++++++++++++++++ run_with_gunicorn.py | 25 ++++++++++++++++++------- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 76901b90..0b4a78b8 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -12,6 +12,7 @@ import os import logging import logging.config import uvicorn +import pipmaster as pm from fastapi.staticfiles import StaticFiles from pathlib import Path import configparser @@ -501,6 +502,21 @@ def configure_logging(): ) +def 
check_and_install_dependencies(): + """Check and install required dependencies""" + required_packages = [ + "uvicorn", + "tiktoken", + "fastapi", + # Add other required packages here + ] + + for package in required_packages: + if not pm.is_installed(package): + print(f"Installing {package}...") + pm.install(package) + print(f"{package} installed successfully") + def main(): # Check if running under Gunicorn if "GUNICORN_CMD_ARGS" in os.environ: @@ -508,6 +524,9 @@ def main(): print("Running under Gunicorn - worker management handled by Gunicorn") return + # Check and install dependencies + check_and_install_dependencies() + from multiprocessing import freeze_support freeze_support() diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 644e6e87..e9d1adae 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -6,9 +6,24 @@ Start LightRAG server with Gunicorn import os import sys import signal +import pipmaster as pm from lightrag.api.utils_api import parse_args, display_splash_screen from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data +def check_and_install_dependencies(): + """Check and install required dependencies""" + required_packages = [ + "gunicorn", + "tiktoken", + # Add other required packages here + ] + + for package in required_packages: + if not pm.is_installed(package): + print(f"Installing {package}...") + pm.install(package) + print(f"{package} installed successfully") + # Signal handler for graceful shutdown def signal_handler(sig, frame): @@ -25,6 +40,9 @@ def signal_handler(sig, frame): def main(): + # Check and install dependencies + check_and_install_dependencies() + # Register signal handlers for graceful shutdown signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGTERM, signal_handler) # kill command @@ -45,13 +63,6 @@ def main(): print(f"Workers setting: {args.workers}") print("=" * 80 + "\n") - # Check and install gunicorn if not present - import pipmaster as pm - - if not pm.is_installed("gunicorn"): - print("Installing gunicorn...") - pm.install("gunicorn") - # Import Gunicorn's StandaloneApplication from gunicorn.app.base import BaseApplication From 8d6960f2805cb3ce91c834bb95218e57af86af2d Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 00:13:11 +0800 Subject: [PATCH 68/77] Fix linting --- lightrag/api/lightrag_server.py | 3 ++- run_with_gunicorn.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0b4a78b8..5f2c437f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -510,13 +510,14 @@ def check_and_install_dependencies(): "fastapi", # Add other required packages here ] - + for package in required_packages: if not pm.is_installed(package): print(f"Installing {package}...") pm.install(package) print(f"{package} installed successfully") + def main(): # Check if running under Gunicorn if "GUNICORN_CMD_ARGS" in os.environ: diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index e9d1adae..69124e31 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -10,6 +10,7 @@ import pipmaster as pm from lightrag.api.utils_api import parse_args, display_splash_screen from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data + def check_and_install_dependencies(): """Check and install required dependencies""" required_packages = [ @@ -17,7 +18,7 @@ def check_and_install_dependencies(): "tiktoken", # Add other required packages 
here ] - + for package in required_packages: if not pm.is_installed(package): print(f"Installing {package}...") @@ -42,7 +43,7 @@ def signal_handler(sig, frame): def main(): # Check and install dependencies check_and_install_dependencies() - + # Register signal handlers for graceful shutdown signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGTERM, signal_handler) # kill command From 7cd25fe5abc36fc24d3a4326aa45d0a5060c8037 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 01:00:27 +0800 Subject: [PATCH 69/77] Improve shared storage cleanup and clarify initialization in multi-worker setup --- lightrag/kg/shared_storage.py | 51 +++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index acebafa7..c8c154aa 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -91,7 +91,7 @@ def initialize_share_data(workers: int = 1): master process before forking worker processes, allowing all workers to share the same initialized data. - In single-process mode, this function is called during LightRAG object initialization. + In single-process mode, this function is called in the FASTAPI lifespan function. The function determines whether to use cross-process shared variables for data storage based on the number of workers. If workers=1, it uses thread locks and local dictionaries. @@ -105,7 +105,6 @@ def initialize_share_data(workers: int = 1): _manager, \ _workers, \ is_multiprocess, \ - is_multiprocess, \ _storage_lock, \ _internal_lock, \ _pipeline_status_lock, \ @@ -152,6 +151,7 @@ async def initialize_pipeline_status(): """ Initialize pipeline namespace with default values. + This function is called during FASTAPI lifespan for each worker. """ pipeline_namespace = await get_namespace_data("pipeline_status") @@ -249,8 +249,8 @@ async def get_all_update_flags_status() -> Dict[str, list]: def try_initialize_namespace(namespace: str) -> bool: """ - Try to initialize a namespace. Returns True if the current process gets initialization permission. - Uses atomic operations on shared dictionaries to ensure only one process can successfully initialize. + Returns True if the current worker(process) gets initialization permission for loading data later. + A worker that does not get the permission is prohibited from loading data from files.
""" global _init_flags, _manager @@ -270,10 +270,10 @@ def try_initialize_namespace(namespace: str) -> bool: async def get_namespace_data(namespace: str) -> Dict[str, Any]: - """get storage space for specific storage type(namespace)""" + """get the shared data reference for specific namespace""" if _shared_dicts is None: direct_log( - f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}", + f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}", level="ERROR", ) raise ValueError("Shared dictionaries not initialized") @@ -301,10 +301,13 @@ def finalize_share_data(): global \ _manager, \ is_multiprocess, \ - _global_lock, \ + _storage_lock, \ + _internal_lock, \ + _pipeline_status_lock, \ _shared_dicts, \ _init_flags, \ - _initialized + _initialized, \ + _update_flags # Check if already initialized if not _initialized: @@ -320,13 +323,36 @@ def finalize_share_data(): # In multi-process mode, shut down the Manager if is_multiprocess and _manager is not None: try: - # Clear shared dictionaries first + # Clear shared resources before shutting down Manager if _shared_dicts is not None: + # Clear pipeline status history messages first if exists + try: + pipeline_status = _shared_dicts.get("pipeline_status", {}) + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].clear() + except Exception: + pass # Ignore any errors during history messages cleanup _shared_dicts.clear() if _init_flags is not None: _init_flags.clear() + if _update_flags is not None: + # Clear each namespace's update flags list and Value objects + try: + for namespace in _update_flags: + flags_list = _update_flags[namespace] + if isinstance(flags_list, list): + # Clear Value objects in the list + for flag in flags_list: + if hasattr( + flag, "value" + ): # Check if it's a Value object + flag.value = False + flags_list.clear() + except Exception: + pass # Ignore any errors during update flags cleanup + _update_flags.clear() - # Shut down the Manager + # Shut down the Manager - this will automatically clean up all shared resources _manager.shutdown() direct_log(f"Process {os.getpid()} Manager shutdown complete") except Exception as e: @@ -340,6 +366,9 @@ def finalize_share_data(): is_multiprocess = None _shared_dicts = None _init_flags = None - _global_lock = None + _storage_lock = None + _internal_lock = None + _pipeline_status_lock = None + _update_flags = None direct_log(f"Process {os.getpid()} storage data finalization complete") From 8fb3670ac78ac126f0c6fb7d967d0f60cfd21c9f Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 02:26:21 +0800 Subject: [PATCH 70/77] Add additional log-related patterns to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 3eb55bd3..6deb14d5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ site/ # Logs / Reports *.log +*.log.* *.logfire *.coverage/ log/ From e20aeada924ed4e0825f6dd85a0bd2018b226a0f Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 02:43:11 +0800 Subject: [PATCH 71/77] docs: add gunicorn deployment guide and update server --- lightrag/api/README.md | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 35062cad..5ffbcdce 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -24,6 +24,8 @@ pip install -e ".[api]" ### Starting API Server with Default Settings +After installing LightRAG with API support, you can start 
LightRAG using this command: `lightrag-server` + LightRAG requires both LLM and Embedding Model to work together to complete document indexing and querying tasks. LightRAG supports binding to various LLM/Embedding backends: * ollama @@ -92,10 +94,40 @@ LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai-ollama LLM_BINDING_API_KEY=your_api_key Light_server --llm-binding openai --embedding-binding openai # start with ollama llm and ollama embedding (no apikey is needed) -Light_server --llm-binding ollama --embedding-binding ollama +light-server --llm-binding ollama --embedding-binding ollama ``` +### Starting API Server with Gunicorn (Production) + +For production deployments, it's recommended to use Gunicorn as the WSGI server to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionalities. + +```bash +# Start with run_with_gunicorn.py +python run_with_gunicorn.py --workers 4 +``` + +The `--workers` parameter is crucial for performance: + +- Determines how many worker processes Gunicorn will spawn to handle requests +- Each worker can handle concurrent requests using asyncio +- Recommended value is (2 x number_of_cores) + 1 +- For example, on a 4-core machine, use 9 workers: (2 x 4) + 1 = 9 +- Consider your server's memory when setting this value, as each worker consumes memory + +Other important startup parameters: + +- `--host`: Server listening address (default: 0.0.0.0) +- `--port`: Server listening port (default: 9621) +- `--timeout`: Request handling timeout (default: 150 seconds) +- `--log-level`: Logging level (default: INFO) +- `--ssl`: Enable HTTPS +- `--ssl-certfile`: Path to SSL certificate file +- `--ssl-keyfile`: Path to SSL private key file + +The command line parameters and environment variables for run_with_gunicorn.py are exactly the same as those for `light-server`.
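To derive the recommended worker count programmatically, a minimal sketch (using only the Python standard library; the helper name is illustrative and not part of LightRAG):

```python
# Sketch of the "(2 x number_of_cores) + 1" heuristic described above.
import os

def recommended_worker_count() -> int:
    cores = os.cpu_count() or 1  # os.cpu_count() can return None on some platforms
    return 2 * cores + 1

if __name__ == "__main__":
    # On a 4-core machine this prints 9.
    print(recommended_worker_count())
```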
### For Azure OpenAI Backend + Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)): ```bash # Change the resource group name, location and OpenAI resource name as needed From 1a5eb200032115d16b5c0e9cad423fdd7561bef6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 04:43:41 +0800 Subject: [PATCH 72/77] Fix history_messages clearing in LightRAG pipeline status initialization --- lightrag/lightrag.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 2cc7883d..1b1afdfc 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -710,11 +710,6 @@ class LightRAG: async with pipeline_status_lock: # Ensure only one worker is processing documents if not pipeline_status.get("busy", False): - # Cleaning history_messages without breaking it as a shared list object - current_history = pipeline_status.get("history_messages", []) - if hasattr(current_history, "clear"): - current_history.clear() - pipeline_status.update( { "busy": True, @@ -725,9 +720,14 @@ class LightRAG: "cur_batch": 0, "request_pending": False, # Clear any previous request "latest_message": "", - "history_messages": current_history, # keep it as a shared list object } ) + # Cleaning history_messages without breaking it as a shared list object + try: + del pipeline_status["history_messages"][:] + except Exception as e: + logger.error(f"Error clearing pipeline history_messages: {str(e)}") + process_documents = True else: # Another process is busy, just set request flag and return From 7124845e558e1b867762a0985c1ab4589f9d3c0f Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 11:09:32 +0800 Subject: [PATCH 73/77] Optimize document processing pipeline with better status tracking & batch handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit โ€ข Add upfront doc processing check โ€ข Optimize pipeline status updates --- lightrag/lightrag.py | 66 +++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 1b1afdfc..8d9c1678 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -706,10 +706,27 @@ class LightRAG: pipeline_status_lock = get_pipeline_status_lock() # Check if another process is already processing the queue - process_documents = False async with pipeline_status_lock: # Ensure only one worker is processing documents if not pipeline_status.get("busy", False): + # First check whether there are any documents that need processing + processing_docs, failed_docs, pending_docs = await asyncio.gather( + self.doc_status.get_docs_by_status(DocStatus.PROCESSING), + self.doc_status.get_docs_by_status(DocStatus.FAILED), + self.doc_status.get_docs_by_status(DocStatus.PENDING), + ) + + to_process_docs: dict[str, DocProcessingStatus] = {} + to_process_docs.update(processing_docs) + to_process_docs.update(failed_docs) + to_process_docs.update(pending_docs) + + # If there are no documents to process, return directly and leave the contents of pipeline_status unchanged + if not to_process_docs: + logger.info("No documents to process") + return + + # There are documents to process, update pipeline_status pipeline_status.update( { "busy": True, @@ -723,37 +740,18 @@ class LightRAG: } ) # Cleaning history_messages without breaking it as a shared list object -
try: - del pipeline_status["history_messages"][:] - except Exception as e: - logger.error(f"Error clearing pipeline history_messages: {str(e)}") - - process_documents = True + del pipeline_status["history_messages"][:] else: # Another process is busy, just set request flag and return pipeline_status["request_pending"] = True logger.info( "Another process is already processing the document queue. Request queued." ) - - if not process_documents: - return + return try: # Process documents until no more documents or requests while True: - # 1. Get all pending, failed, and abnormally terminated processing documents. - processing_docs, failed_docs, pending_docs = await asyncio.gather( - self.doc_status.get_docs_by_status(DocStatus.PROCESSING), - self.doc_status.get_docs_by_status(DocStatus.FAILED), - self.doc_status.get_docs_by_status(DocStatus.PENDING), - ) - - to_process_docs: dict[str, DocProcessingStatus] = {} - to_process_docs.update(processing_docs) - to_process_docs.update(failed_docs) - to_process_docs.update(pending_docs) - if not to_process_docs: log_message = "All documents have been processed or are duplicates" logger.info(log_message) @@ -761,20 +759,18 @@ class LightRAG: pipeline_status["history_messages"].append(log_message) break - # Update pipeline status with document count (with lock) - pipeline_status["docs"] = len(to_process_docs) - # 2. split docs into chunks, insert chunks, update doc status docs_batches = [ list(to_process_docs.items())[i : i + self.max_parallel_insert] for i in range(0, len(to_process_docs), self.max_parallel_insert) ] - # Update pipeline status with batch information (directly, as it's atomic) - pipeline_status.update({"batchs": len(docs_batches), "cur_batch": 0}) - log_message = f"Number of batches to process: {len(docs_batches)}." logger.info(log_message) + + # Update pipeline status with current batch information + pipeline_status["docs"] += len(to_process_docs) + pipeline_status["batchs"] += len(docs_batches) pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) @@ -782,7 +778,7 @@ class LightRAG: # 3. 
iterate over batches for batch_idx, docs_batch in enumerate(docs_batches): # Update current batch in pipeline status (directly, as it's atomic) - pipeline_status["cur_batch"] = batch_idx + 1 + pipeline_status["cur_batch"] += 1 async def batch( batch_idx: int, @@ -895,6 +891,18 @@ class LightRAG: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) + # Get the new set of pending documents to process + processing_docs, failed_docs, pending_docs = await asyncio.gather( + self.doc_status.get_docs_by_status(DocStatus.PROCESSING), + self.doc_status.get_docs_by_status(DocStatus.FAILED), + self.doc_status.get_docs_by_status(DocStatus.PENDING), + ) + + to_process_docs = {} + to_process_docs.update(processing_docs) + to_process_docs.update(failed_docs) + to_process_docs.update(pending_docs) + finally: log_message = "Document processing pipeline completed" logger.info(log_message) From 396b8c3347f91d69888836a546164ddfc5f5b044 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 12:44:58 +0800 Subject: [PATCH 74/77] Add psutil to required dependencies for runtime monitoring --- run_with_gunicorn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py index 69124e31..2e4e3cf7 100755 --- a/run_with_gunicorn.py +++ b/run_with_gunicorn.py @@ -16,6 +16,7 @@ def check_and_install_dependencies(): required_packages = [ "gunicorn", "tiktoken", + "psutil", # Add other required packages here ] From fb5f11f59487ee8c65bf234e498dac3759b9df2c Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 18:17:51 +0800 Subject: [PATCH 75/77] Add Gunicorn support for production deployment of LightRAG server - Move gunicorn startup and config files to api package - Create new CLI entry point for Gunicorn mode --- .../api/gunicorn_config.py | 0 lightrag/api/run_with_gunicorn.py | 203 ++++++++++++++++++ setup.py | 1 + 3 files changed, 204 insertions(+) rename gunicorn_config.py => lightrag/api/gunicorn_config.py (100%) create mode 100644 lightrag/api/run_with_gunicorn.py diff --git a/gunicorn_config.py b/lightrag/api/gunicorn_config.py similarity index 100% rename from gunicorn_config.py rename to lightrag/api/gunicorn_config.py diff --git a/lightrag/api/run_with_gunicorn.py b/lightrag/api/run_with_gunicorn.py new file mode 100644 index 00000000..903c5c17 --- /dev/null +++ b/lightrag/api/run_with_gunicorn.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +""" +Start LightRAG server with Gunicorn +""" + +import os +import sys +import signal +import pipmaster as pm +from lightrag.api.utils_api import parse_args, display_splash_screen +from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data + + +def check_and_install_dependencies(): + """Check and install required dependencies""" + required_packages = [ + "gunicorn", + "tiktoken", + "psutil", + # Add other required packages here + ] + + for package in required_packages: + if not pm.is_installed(package): + print(f"Installing {package}...") + pm.install(package) + print(f"{package} installed successfully") + + +# Signal handler for graceful shutdown +def signal_handler(sig, frame): + print("\n\n" + "=" * 80) + print("RECEIVED TERMINATION SIGNAL") + print(f"Process ID: {os.getpid()}") + print("=" * 80 + "\n") + + # Release shared resources + finalize_share_data() + + # Exit with success status + sys.exit(0) + + +def main(): + # Check and install dependencies + check_and_install_dependencies() + + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT,
signal_handler) # Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # kill command + + # Parse all arguments using parse_args + args = parse_args(is_uvicorn_mode=False) + + # Display startup information + display_splash_screen(args) + + print("๐Ÿš€ Starting LightRAG with Gunicorn") + print(f"๐Ÿ”„ Worker management: Gunicorn (workers={args.workers})") + print("๐Ÿ” Preloading app: Enabled") + print("๐Ÿ“ Note: Using Gunicorn's preload feature for shared data initialization") + print("\n\n" + "=" * 80) + print("MAIN PROCESS INITIALIZATION") + print(f"Process ID: {os.getpid()}") + print(f"Workers setting: {args.workers}") + print("=" * 80 + "\n") + + # Import Gunicorn's StandaloneApplication + from gunicorn.app.base import BaseApplication + + # Define a custom application class that loads our config + class GunicornApp(BaseApplication): + def __init__(self, app, options=None): + self.options = options or {} + self.application = app + super().__init__() + + def load_config(self): + # Define valid Gunicorn configuration options + valid_options = { + "bind", + "workers", + "worker_class", + "timeout", + "keepalive", + "preload_app", + "errorlog", + "accesslog", + "loglevel", + "certfile", + "keyfile", + "limit_request_line", + "limit_request_fields", + "limit_request_field_size", + "graceful_timeout", + "max_requests", + "max_requests_jitter", + } + + # Special hooks that need to be set separately + special_hooks = { + "on_starting", + "on_reload", + "on_exit", + "pre_fork", + "post_fork", + "pre_exec", + "pre_request", + "post_request", + "worker_init", + "worker_exit", + "nworkers_changed", + "child_exit", + } + + # Import and configure the gunicorn_config module + from lightrag.api import gunicorn_config + + # Set configuration variables in gunicorn_config, prioritizing command line arguments + gunicorn_config.workers = ( + args.workers if args.workers else int(os.getenv("WORKERS", 1)) + ) + + # Bind configuration prioritizes command line arguments + host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") + port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) + gunicorn_config.bind = f"{host}:{port}" + + # Log level configuration prioritizes command line arguments + gunicorn_config.loglevel = ( + args.log_level.lower() + if args.log_level + else os.getenv("LOG_LEVEL", "info") + ) + + # Timeout configuration prioritizes command line arguments + gunicorn_config.timeout = ( + args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) + ) + + # Keepalive configuration + gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) + + # SSL configuration prioritizes command line arguments + if args.ssl or os.getenv("SSL", "").lower() in ( + "true", + "1", + "yes", + "t", + "on", + ): + gunicorn_config.certfile = ( + args.ssl_certfile + if args.ssl_certfile + else os.getenv("SSL_CERTFILE") + ) + gunicorn_config.keyfile = ( + args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") + ) + + # Set configuration options from the module + for key in dir(gunicorn_config): + if key in valid_options: + value = getattr(gunicorn_config, key) + # Skip functions like on_starting and None values + if not callable(value) and value is not None: + self.cfg.set(key, value) + # Set special hooks + elif key in special_hooks: + value = getattr(gunicorn_config, key) + if callable(value): + self.cfg.set(key, value) + + if hasattr(gunicorn_config, "logconfig_dict"): + self.cfg.set( + "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") + ) + + def 
load(self): + # Import the application + from lightrag.api.lightrag_server import get_application + + return get_application(args) + + # Create the application + app = GunicornApp("") + + # Force workers to be an integer and greater than 1 for multi-process mode + workers_count = int(args.workers) + if workers_count > 1: + # Set a flag to indicate we're in the main process + os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" + initialize_share_data(workers_count) + else: + initialize_share_data(1) + + # Run the application + print("\nStarting Gunicorn with direct Python API...") + app.run() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index c190bd4d..b9063d7d 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ setuptools.setup( entry_points={ "console_scripts": [ "lightrag-server=lightrag.api.lightrag_server:main [api]", + "lightrag-gunicorn=lightrag.api.run_with_gunicorn:main [api]", "lightrag-viewer=lightrag.tools.lightrag_visualizer.graph_visualizer:main [tools]", ], }, From fca6969b0b8918bee9d7601d5e9bf4f5a7b8c053 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 18:33:18 +0800 Subject: [PATCH 76/77] Update Gunicorn startup instructions in API documentation --- lightrag/api/README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 5ffbcdce..8f61f2f6 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -102,8 +102,11 @@ light-server --llm-binding ollama --embedding-binding ollama For production deployments, it's recommended to use Gunicorn as the WSGI server to handle concurrent requests efficiently. LightRAG provides a dedicated Gunicorn startup script that handles shared data initialization, process management, and other critical functionalities. ```bash -# Start with run_with_gunicorn.py -python run_with_gunicorn.py --workers 4 +# Start with lightrag-gunicorn command +lightrag-gunicorn --workers 4 + +# Alternatively, you can use the module directly +python -m lightrag.api.run_with_gunicorn --workers 4 ``` The `--workers` parameter is crucial for performance: From 12fbc5b851ffc7ad9570ca69f2833be1140f29ba Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 19:12:11 +0800 Subject: [PATCH 77/77] Updated env.example with better documentation and clarity - Renamed file from .env.example to env.example --- .env.example => env.example | 2 ++ 1 file changed, 2 insertions(+) rename .env.example => env.example (99%) diff --git a/.env.example b/env.example similarity index 99% rename from .env.example rename to env.example index de9b6452..112676c6 100644 --- a/.env.example +++ b/env.example @@ -1,3 +1,5 @@ +### This is sample file of .env + ### Server Configuration # HOST=0.0.0.0 # PORT=9621
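The sample configuration above follows the same precedence used by the startup scripts: an explicit command line argument overrides the corresponding environment variable, which overrides the built-in default. A minimal sketch of that resolution order (the helper name and the hard-coded defaults are illustrative assumptions, not LightRAG APIs):

```python
# Resolve a setting: CLI argument first, then environment variable, then default.
import os

def resolve_setting(cli_value, env_name, default, cast=str):
    if cli_value is not None:
        return cli_value
    raw = os.getenv(env_name)
    return cast(raw) if raw is not None else default

host = resolve_setting(None, "HOST", "0.0.0.0")       # HOST as in the sample .env
port = resolve_setting(None, "PORT", 9621, int)       # PORT as in the sample .env
timeout = resolve_setting(None, "TIMEOUT", 150, int)  # default matches gunicorn_config.py
print(host, port, timeout)
```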