From 7d12715f098c8b8365a97976452eeab8da9225b1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 26 Feb 2025 18:11:02 +0800 Subject: [PATCH] Refactor shared storage to safely handle multi-process initialization and data sharing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add namespace initialization check • Use atomic operations for shared data --- lightrag/kg/json_doc_status_impl.py | 20 +++-- lightrag/kg/shared_storage.py | 128 +++++++++++++++++----------- 2 files changed, 93 insertions(+), 55 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 58ee3666..2a85c68a 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -12,7 +12,11 @@ from lightrag.utils import ( logger, write_json, ) -from .shared_storage import get_namespace_data, get_storage_lock +from .shared_storage import ( + get_namespace_data, + get_storage_lock, + try_initialize_namespace, +) @final @@ -24,11 +28,17 @@ class JsonDocStatusStorage(DocStatusStorage): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._storage_lock = get_storage_lock() + + # check need_init must before get_namespace_data + need_init = try_initialize_namespace(self.namespace) self._data = get_namespace_data(self.namespace) - with self._storage_lock: - if not self._data: - self._data.update(load_json(self._file_name) or {}) - logger.info(f"Loaded document status storage with {len(self._data)} records") + if need_init: + loaded_data = load_json(self._file_name) or {} + with self._storage_lock: + self._data.update(loaded_data) + logger.info( + f"Loaded document status storage with {len(loaded_data)} records" + ) async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 9de3bb79..27aca9d0 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,30 +1,74 @@ +import os from multiprocessing.synchronize import Lock as ProcessLock from threading import Lock as ThreadLock from multiprocessing import Manager from typing import Any, Dict, Optional, Union +from lightrag.utils import logger -# 定义类型变量 LockType = Union[ProcessLock, ThreadLock] -# 全局变量 -_shared_data: Optional[Dict[str, Any]] = None -_namespace_objects: Optional[Dict[str, Any]] = None -_global_lock: Optional[LockType] = None is_multiprocess = False -manager = None -def initialize_manager(): - """Initialize manager, only for multiple processes where workers > 1""" - global manager - if manager is None: - manager = Manager() +_manager = None +_global_lock: Optional[LockType] = None + +# shared data for storage across processes +_shared_dicts: Optional[Dict[str, Any]] = {} +_share_objects: Optional[Dict[str, Any]] = {} +_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized + +def initialize_share_data(): + """Initialize shared data, only called if multiple processes where workers > 1""" + global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess + is_multiprocess = True + + logger.info(f"Process {os.getpid()} initializing shared storage") + + # Initialize manager + if _manager is None: + _manager = Manager() + logger.info(f"Process {os.getpid()} created manager") + + # Create shared dictionaries with manager + _shared_dicts = _manager.dict() + _share_objects = _manager.dict() + _init_flags = _manager.dict() # 使用共享字典存储初始化标志 + logger.info(f"Process {os.getpid()} created shared dictionaries") + +def try_initialize_namespace(namespace: str) -> bool: + """ + 尝试初始化命名空间。返回True表示当前进程获得了初始化权限。 + 使用共享字典的原子操作确保只有一个进程能成功初始化。 + """ + global _init_flags, _manager + + if is_multiprocess: + if _init_flags is None: + raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.") + else: + if _init_flags is None: + _init_flags = {} + + logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") + + # 使用全局锁保护共享字典的访问 + with _get_global_lock(): + # 检查是否已经初始化 + if namespace not in _init_flags: + # 设置初始化标志 + _init_flags[namespace] = True + logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}") + return True + + logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized") + return False def _get_global_lock() -> LockType: - global _global_lock, is_multiprocess + global _global_lock, is_multiprocess, _manager if _global_lock is None: if is_multiprocess: - _global_lock = manager.Lock() + _global_lock = _manager.Lock() # Use manager for lock else: _global_lock = ThreadLock() @@ -38,56 +82,40 @@ def get_scan_lock() -> LockType: """return scan_progress lock for data consistency""" return get_storage_lock() -def get_shared_data() -> Dict[str, Any]: - """ - return shared data for all storage types - create mult-process save share data only if need for better performance - """ - global _shared_data, is_multiprocess - - if _shared_data is None: - lock = _get_global_lock() - with lock: - if _shared_data is None: - if is_multiprocess: - _shared_data = manager.dict() - else: - _shared_data = {} - - return _shared_data - def get_namespace_object(namespace: str) -> Any: """Get an object for specific namespace""" - global _namespace_objects, is_multiprocess - - if _namespace_objects is None: + global _share_objects, is_multiprocess, _manager + + if is_multiprocess and not _manager: + raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + + if namespace not in _share_objects: lock = _get_global_lock() with lock: - if _namespace_objects is None: - _namespace_objects = {} - - if namespace not in _namespace_objects: - lock = _get_global_lock() - with lock: - if namespace not in _namespace_objects: + if namespace not in _share_objects: if is_multiprocess: - _namespace_objects[namespace] = manager.Value('O', None) + _share_objects[namespace] = _manager.Value('O', None) else: - _namespace_objects[namespace] = None + _share_objects[namespace] = None - return _namespace_objects[namespace] + return _share_objects[namespace] + +# 移除不再使用的函数 def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" - shared_data = get_shared_data() - lock = _get_global_lock() + global _shared_dicts, is_multiprocess, _manager - if namespace not in shared_data: + if is_multiprocess and not _manager: + raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.") + + if namespace not in _shared_dicts: + lock = _get_global_lock() with lock: - if namespace not in shared_data: - shared_data[namespace] = {} + if namespace not in _shared_dicts: + _shared_dicts[namespace] = {} - return shared_data[namespace] + return _shared_dicts[namespace] def get_scan_progress() -> Dict[str, Any]: """get storage space for document scanning progress data"""