Refactor shared storage to safely handle multi-process initialization and data sharing
• Add namespace initialization check • Use atomic operations for shared data
This commit is contained in:
@@ -12,7 +12,11 @@ from lightrag.utils import (
|
||||
logger,
|
||||
write_json,
|
||||
)
|
||||
from .shared_storage import get_namespace_data, get_storage_lock
|
||||
from .shared_storage import (
|
||||
get_namespace_data,
|
||||
get_storage_lock,
|
||||
try_initialize_namespace,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@@ -24,11 +28,17 @@ class JsonDocStatusStorage(DocStatusStorage):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
# check need_init must before get_namespace_data
|
||||
need_init = try_initialize_namespace(self.namespace)
|
||||
self._data = get_namespace_data(self.namespace)
|
||||
with self._storage_lock:
|
||||
if not self._data:
|
||||
self._data.update(load_json(self._file_name) or {})
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
if need_init:
|
||||
loaded_data = load_json(self._file_name) or {}
|
||||
with self._storage_lock:
|
||||
self._data.update(loaded_data)
|
||||
logger.info(
|
||||
f"Loaded document status storage with {len(loaded_data)} records"
|
||||
)
|
||||
|
||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
|
@@ -1,30 +1,74 @@
|
||||
import os
|
||||
from multiprocessing.synchronize import Lock as ProcessLock
|
||||
from threading import Lock as ThreadLock
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from lightrag.utils import logger
|
||||
|
||||
# 定义类型变量
|
||||
LockType = Union[ProcessLock, ThreadLock]
|
||||
|
||||
# 全局变量
|
||||
_shared_data: Optional[Dict[str, Any]] = None
|
||||
_namespace_objects: Optional[Dict[str, Any]] = None
|
||||
_global_lock: Optional[LockType] = None
|
||||
is_multiprocess = False
|
||||
manager = None
|
||||
|
||||
def initialize_manager():
|
||||
"""Initialize manager, only for multiple processes where workers > 1"""
|
||||
global manager
|
||||
if manager is None:
|
||||
manager = Manager()
|
||||
_manager = None
|
||||
_global_lock: Optional[LockType] = None
|
||||
|
||||
# shared data for storage across processes
|
||||
_shared_dicts: Optional[Dict[str, Any]] = {}
|
||||
_share_objects: Optional[Dict[str, Any]] = {}
|
||||
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
|
||||
|
||||
def initialize_share_data():
|
||||
"""Initialize shared data, only called if multiple processes where workers > 1"""
|
||||
global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
|
||||
is_multiprocess = True
|
||||
|
||||
logger.info(f"Process {os.getpid()} initializing shared storage")
|
||||
|
||||
# Initialize manager
|
||||
if _manager is None:
|
||||
_manager = Manager()
|
||||
logger.info(f"Process {os.getpid()} created manager")
|
||||
|
||||
# Create shared dictionaries with manager
|
||||
_shared_dicts = _manager.dict()
|
||||
_share_objects = _manager.dict()
|
||||
_init_flags = _manager.dict() # 使用共享字典存储初始化标志
|
||||
logger.info(f"Process {os.getpid()} created shared dictionaries")
|
||||
|
||||
def try_initialize_namespace(namespace: str) -> bool:
|
||||
"""
|
||||
尝试初始化命名空间。返回True表示当前进程获得了初始化权限。
|
||||
使用共享字典的原子操作确保只有一个进程能成功初始化。
|
||||
"""
|
||||
global _init_flags, _manager
|
||||
|
||||
if is_multiprocess:
|
||||
if _init_flags is None:
|
||||
raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
|
||||
else:
|
||||
if _init_flags is None:
|
||||
_init_flags = {}
|
||||
|
||||
logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")
|
||||
|
||||
# 使用全局锁保护共享字典的访问
|
||||
with _get_global_lock():
|
||||
# 检查是否已经初始化
|
||||
if namespace not in _init_flags:
|
||||
# 设置初始化标志
|
||||
_init_flags[namespace] = True
|
||||
logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}")
|
||||
return True
|
||||
|
||||
logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized")
|
||||
return False
|
||||
|
||||
def _get_global_lock() -> LockType:
|
||||
global _global_lock, is_multiprocess
|
||||
global _global_lock, is_multiprocess, _manager
|
||||
|
||||
if _global_lock is None:
|
||||
if is_multiprocess:
|
||||
_global_lock = manager.Lock()
|
||||
_global_lock = _manager.Lock() # Use manager for lock
|
||||
else:
|
||||
_global_lock = ThreadLock()
|
||||
|
||||
@@ -38,56 +82,40 @@ def get_scan_lock() -> LockType:
|
||||
"""return scan_progress lock for data consistency"""
|
||||
return get_storage_lock()
|
||||
|
||||
def get_shared_data() -> Dict[str, Any]:
|
||||
"""
|
||||
return shared data for all storage types
|
||||
create mult-process save share data only if need for better performance
|
||||
"""
|
||||
global _shared_data, is_multiprocess
|
||||
|
||||
if _shared_data is None:
|
||||
lock = _get_global_lock()
|
||||
with lock:
|
||||
if _shared_data is None:
|
||||
if is_multiprocess:
|
||||
_shared_data = manager.dict()
|
||||
else:
|
||||
_shared_data = {}
|
||||
|
||||
return _shared_data
|
||||
|
||||
def get_namespace_object(namespace: str) -> Any:
|
||||
"""Get an object for specific namespace"""
|
||||
global _namespace_objects, is_multiprocess
|
||||
|
||||
if _namespace_objects is None:
|
||||
global _share_objects, is_multiprocess, _manager
|
||||
|
||||
if is_multiprocess and not _manager:
|
||||
raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
|
||||
|
||||
if namespace not in _share_objects:
|
||||
lock = _get_global_lock()
|
||||
with lock:
|
||||
if _namespace_objects is None:
|
||||
_namespace_objects = {}
|
||||
|
||||
if namespace not in _namespace_objects:
|
||||
lock = _get_global_lock()
|
||||
with lock:
|
||||
if namespace not in _namespace_objects:
|
||||
if namespace not in _share_objects:
|
||||
if is_multiprocess:
|
||||
_namespace_objects[namespace] = manager.Value('O', None)
|
||||
_share_objects[namespace] = _manager.Value('O', None)
|
||||
else:
|
||||
_namespace_objects[namespace] = None
|
||||
_share_objects[namespace] = None
|
||||
|
||||
return _namespace_objects[namespace]
|
||||
return _share_objects[namespace]
|
||||
|
||||
# 移除不再使用的函数
|
||||
|
||||
def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
||||
"""get storage space for specific storage type(namespace)"""
|
||||
shared_data = get_shared_data()
|
||||
lock = _get_global_lock()
|
||||
global _shared_dicts, is_multiprocess, _manager
|
||||
|
||||
if namespace not in shared_data:
|
||||
if is_multiprocess and not _manager:
|
||||
raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
|
||||
|
||||
if namespace not in _shared_dicts:
|
||||
lock = _get_global_lock()
|
||||
with lock:
|
||||
if namespace not in shared_data:
|
||||
shared_data[namespace] = {}
|
||||
if namespace not in _shared_dicts:
|
||||
_shared_dicts[namespace] = {}
|
||||
|
||||
return shared_data[namespace]
|
||||
return _shared_dicts[namespace]
|
||||
|
||||
def get_scan_progress() -> Dict[str, Any]:
|
||||
"""get storage space for document scanning progress data"""
|
||||
|
Reference in New Issue
Block a user