Refactor shared storage to safely handle multi-process initialization and data sharing

• Add namespace initialization check
• Use atomic operations for shared data
This commit is contained in:
yangdx
2025-02-26 18:11:02 +08:00
parent 4eb069d1d6
commit 7d12715f09
2 changed files with 93 additions and 55 deletions

View File

@@ -12,7 +12,11 @@ from lightrag.utils import (
logger, logger,
write_json, write_json,
) )
from .shared_storage import get_namespace_data, get_storage_lock from .shared_storage import (
get_namespace_data,
get_storage_lock,
try_initialize_namespace,
)
@final @final
@@ -24,11 +28,17 @@ class JsonDocStatusStorage(DocStatusStorage):
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._storage_lock = get_storage_lock() self._storage_lock = get_storage_lock()
# need_init must be determined before calling get_namespace_data
need_init = try_initialize_namespace(self.namespace)
self._data = get_namespace_data(self.namespace) self._data = get_namespace_data(self.namespace)
with self._storage_lock: if need_init:
if not self._data: loaded_data = load_json(self._file_name) or {}
self._data.update(load_json(self._file_name) or {}) with self._storage_lock:
logger.info(f"Loaded document status storage with {len(self._data)} records") self._data.update(loaded_data)
logger.info(
f"Loaded document status storage with {len(loaded_data)} records"
)
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)""" """Return keys that should be processed (not in storage or not successfully processed)"""

View File

@@ -1,30 +1,74 @@
import os
from multiprocessing.synchronize import Lock as ProcessLock from multiprocessing.synchronize import Lock as ProcessLock
from threading import Lock as ThreadLock from threading import Lock as ThreadLock
from multiprocessing import Manager from multiprocessing import Manager
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
from lightrag.utils import logger
# Define type aliases
LockType = Union[ProcessLock, ThreadLock] LockType = Union[ProcessLock, ThreadLock]
# Global variables
_shared_data: Optional[Dict[str, Any]] = None
_namespace_objects: Optional[Dict[str, Any]] = None
_global_lock: Optional[LockType] = None
is_multiprocess = False is_multiprocess = False
manager = None
def initialize_manager(): _manager = None
"""Initialize manager, only for multiple processes where workers > 1""" _global_lock: Optional[LockType] = None
global manager
if manager is None: # shared data for storage across processes
manager = Manager() _shared_dicts: Optional[Dict[str, Any]] = {}
_share_objects: Optional[Dict[str, Any]] = {}
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
def initialize_share_data():
    """Initialize manager-backed shared state for multi-process mode.

    Only meant to be called when running with multiple worker processes
    (workers > 1). Sets ``is_multiprocess`` to True and, the first time it
    runs in a process, creates the ``multiprocessing.Manager`` together with
    the shared dictionaries ``_shared_dicts``, ``_share_objects`` and
    ``_init_flags`` (namespace -> initialized).

    Repeated calls are safe: once the manager exists the shared dictionaries
    are left untouched, so namespace data and initialization flags that were
    already registered are not discarded.
    """
    global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess

    is_multiprocess = True
    logger.info(f"Process {os.getpid()} initializing shared storage")

    if _manager is not None:
        # Already initialized. Rebinding the dictionaries here would orphan
        # every namespace and init flag stored so far.
        return

    # Initialize manager
    _manager = Manager()
    logger.info(f"Process {os.getpid()} created manager")

    # Create shared dictionaries with manager
    _shared_dicts = _manager.dict()
    _share_objects = _manager.dict()
    _init_flags = _manager.dict()  # shared dict holding the initialization flags
    logger.info(f"Process {os.getpid()} created shared dictionaries")
def try_initialize_namespace(namespace: str) -> bool:
    """Try to claim the right to initialize ``namespace``.

    Returns True when this call recorded the initialization flag for the
    namespace, meaning the caller is the one that should initialize it.
    Returns False when a flag for the namespace was already present.
    The check and the flag update are performed while holding the global lock.

    Raises:
        RuntimeError: in multi-process mode when ``_init_flags`` has not been
            created yet (``initialize_share_data()`` was not called).
    """
    global _init_flags

    if _init_flags is None:
        if is_multiprocess:
            raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
        # Single-process mode: a plain dict is created on first use.
        _init_flags = {}

    logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")

    # Guard access to the flag dictionary with the global lock.
    with _get_global_lock():
        already_flagged = namespace in _init_flags
        if not already_flagged:
            # Record the flag so later callers see the namespace as taken.
            _init_flags[namespace] = True
            logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}")
            return True

        logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized")
        return False
def _get_global_lock() -> LockType: def _get_global_lock() -> LockType:
global _global_lock, is_multiprocess global _global_lock, is_multiprocess, _manager
if _global_lock is None: if _global_lock is None:
if is_multiprocess: if is_multiprocess:
_global_lock = manager.Lock() _global_lock = _manager.Lock() # Use manager for lock
else: else:
_global_lock = ThreadLock() _global_lock = ThreadLock()
@@ -38,56 +82,40 @@ def get_scan_lock() -> LockType:
"""return scan_progress lock for data consistency""" """return scan_progress lock for data consistency"""
return get_storage_lock() return get_storage_lock()
def get_shared_data() -> Dict[str, Any]:
    """Return the container holding shared data for all storage types.

    The container is created on first use: a ``manager.dict()`` when
    ``is_multiprocess`` is true, otherwise a plain ``dict``. Later calls
    return the same object.
    """
    global _shared_data, is_multiprocess

    # Fast path: container already exists, no lock needed.
    if _shared_data is not None:
        return _shared_data

    with _get_global_lock():
        # Check again under the lock before creating the container.
        if _shared_data is None:
            _shared_data = manager.dict() if is_multiprocess else {}
    return _shared_data
def get_namespace_object(namespace: str) -> Any: def get_namespace_object(namespace: str) -> Any:
"""Get an object for specific namespace""" """Get an object for specific namespace"""
global _namespace_objects, is_multiprocess global _share_objects, is_multiprocess, _manager
if _namespace_objects is None: if is_multiprocess and not _manager:
raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
if namespace not in _share_objects:
lock = _get_global_lock() lock = _get_global_lock()
with lock: with lock:
if _namespace_objects is None: if namespace not in _share_objects:
_namespace_objects = {}
if namespace not in _namespace_objects:
lock = _get_global_lock()
with lock:
if namespace not in _namespace_objects:
if is_multiprocess: if is_multiprocess:
_namespace_objects[namespace] = manager.Value('O', None) _share_objects[namespace] = _manager.Value('O', None)
else: else:
_namespace_objects[namespace] = None _share_objects[namespace] = None
return _namespace_objects[namespace] return _share_objects[namespace]
# Removed functions that are no longer used
def get_namespace_data(namespace: str) -> Dict[str, Any]: def get_namespace_data(namespace: str) -> Dict[str, Any]:
"""get storage space for specific storage type(namespace)""" """get storage space for specific storage type(namespace)"""
shared_data = get_shared_data() global _shared_dicts, is_multiprocess, _manager
lock = _get_global_lock()
if namespace not in shared_data: if is_multiprocess and not _manager:
raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
if namespace not in _shared_dicts:
lock = _get_global_lock()
with lock: with lock:
if namespace not in shared_data: if namespace not in _shared_dicts:
shared_data[namespace] = {} _shared_dicts[namespace] = {}
return shared_data[namespace] return _shared_dicts[namespace]
def get_scan_progress() -> Dict[str, Any]: def get_scan_progress() -> Dict[str, Any]:
"""get storage space for document scanning progress data""" """get storage space for document scanning progress data"""