From 7c237920b105643361940d21247339a1c1a3c765 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 08:48:33 +0800 Subject: [PATCH] Refactor shared storage to support both single and multi-process modes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Initialize storage based on worker count • Remove redundant global variable checks • Add explicit mutex initialization • Centralize shared storage initialization • Fix process/thread lock selection logic • Mark shared storage as initialized so repeated initialize_share_data() calls are no-ops --- lightrag/api/lightrag_server.py | 12 ++--- lightrag/kg/shared_storage.py | 89 ++++++++++++++------------------- lightrag/lightrag.py | 3 ++ 3 files changed, 45 insertions(+), 59 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 270bbb24..3af8887d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -96,10 +96,6 @@ def create_app(args): logger.setLevel(getattr(logging, args.log_level)) set_verbose_debug(args.verbose) - from lightrag.kg.shared_storage import is_multiprocess - - logger.info(f"==== Multi-processor mode: {is_multiprocess} ====") - # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", @@ -422,11 +418,6 @@ def get_application(): args = types.SimpleNamespace(**json.loads(args_json)) - if args.workers > 1: - from lightrag.kg.shared_storage import initialize_share_data - - initialize_share_data() - return create_app(args) @@ -492,6 +483,9 @@ def main(): display_splash_screen(args) + from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data(args.workers) + uvicorn_config = { "app": "lightrag.api.lightrag_server:get_application", "factory": True, diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index bd4c55fe..6b5c07f6 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -7,35 +7,52 @@ from lightrag.utils import logger LockType = Union[ProcessLock, ThreadLock] 
-is_multiprocess = False - _manager = None -_global_lock: Optional[LockType] = None +_initialized = None +_is_multiprocess = None +is_multiprocess = None # shared data for storage across processes -_shared_dicts: Optional[Dict[str, Any]] = {} -_share_objects: Optional[Dict[str, Any]] = {} +_shared_dicts: Optional[Dict[str, Any]] = None +_share_objects: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized +_global_lock: Optional[LockType] = None -def initialize_share_data(): - """Initialize shared data, only called if multiple processes where workers > 1""" - global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess - is_multiprocess = True - logger.info(f"Process {os.getpid()} initializing shared storage") +def initialize_share_data(workers: int = 1): + """Initialize shared storage once: plain dicts and a ThreadLock when workers == 1, Manager-backed dicts and lock otherwise.""" + global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized + + if _initialized and _initialized.value: + is_multiprocess = _is_multiprocess.value + if _is_multiprocess.value: + logger.info(f"Process {os.getpid()} storage data already initialized!") + return - # Initialize manager - if _manager is None: - _manager = Manager() - logger.info(f"Process {os.getpid()} created manager") + _manager = Manager() + _initialized = _manager.Value("b", False) + _is_multiprocess = _manager.Value("b", False) - # Create shared dictionaries with manager - _shared_dicts = _manager.dict() - _share_objects = _manager.dict() - _init_flags = _manager.dict() # 使用共享字典存储初始化标志 - logger.info(f"Process {os.getpid()} created shared dictionaries") + if workers == 1: + _is_multiprocess.value = False + _global_lock = ThreadLock() + _shared_dicts = {} + _share_objects = {} + _init_flags = {} + logger.info(f"Process {os.getpid()} storage data created for Single Process") + else: + _is_multiprocess.value = True + _global_lock = _manager.Lock() + # Create shared dictionaries with manager 
+ _shared_dicts = _manager.dict() + _share_objects = _manager.dict() + _init_flags = _manager.dict() # Use a shared dict to store the initialization flags + logger.info(f"Process {os.getpid()} storage data created for Multiple Process") + is_multiprocess = _is_multiprocess.value + # Mark initialization as done; without this the guard above never fires and a later call (e.g. from LightRAG.__post_init__) would rebuild the manager, lock and dicts + _initialized.value = True def try_initialize_namespace(namespace: str) -> bool: """ @@ -44,7 +61,7 @@ def try_initialize_namespace(namespace: str) -> bool: """ global _init_flags, _manager - if is_multiprocess: + if _is_multiprocess.value: if _init_flags is None: raise RuntimeError( "Shared storage not initialized. Call initialize_share_data() first." @@ -55,17 +72,13 @@ def try_initialize_namespace(namespace: str) -> bool: logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") - # 使用全局锁保护共享字典的访问 - with _get_global_lock(): - # 检查是否已经初始化 + with _global_lock: if namespace not in _init_flags: - # 设置初始化标志 _init_flags[namespace] = True logger.info( f"Process {os.getpid()} ready to initialize namespace {namespace}" ) return True - logger.info( f"Process {os.getpid()} found namespace {namespace} already initialized" ) @@ -73,14 +86,6 @@ def try_initialize_namespace(namespace: str) -> bool: def _get_global_lock() -> LockType: - global _global_lock, is_multiprocess, _manager - - if _global_lock is None: - if is_multiprocess: - _global_lock = _manager.Lock() # Use manager for lock - else: - _global_lock = ThreadLock() - return _global_lock @@ -96,36 +101,20 @@ def get_scan_lock() -> LockType: def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" - global _share_objects, is_multiprocess, _manager - - if is_multiprocess and not _manager: - raise RuntimeError( - "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." 
- ) if namespace not in _share_objects: lock = _get_global_lock() with lock: if namespace not in _share_objects: - if is_multiprocess: + if _is_multiprocess.value: _share_objects[namespace] = _manager.Value("O", None) else: _share_objects[namespace] = None return _share_objects[namespace] - -# 移除不再使用的函数 - - def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" - global _shared_dicts, is_multiprocess, _manager - - if is_multiprocess and not _manager: - raise RuntimeError( - "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first." - ) if namespace not in _shared_dicts: lock = _get_global_lock() diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 46638243..08ca202f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -267,6 +267,9 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): + from lightrag.kg.shared_storage import initialize_share_data + initialize_share_data() + os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path, self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}")