Refactor shared storage to support both single and multi-process modes

• Initialize storage based on worker count
• Remove redundant global variable checks
• Add explicit mutex initialization
• Centralize shared storage initialization
• Fix process/thread lock selection logic
This commit is contained in:
yangdx
2025-02-27 08:48:33 +08:00
parent 7436c06f6c
commit 7c237920b1
3 changed files with 43 additions and 59 deletions

View File

@@ -96,10 +96,6 @@ def create_app(args):
logger.setLevel(getattr(logging, args.log_level)) logger.setLevel(getattr(logging, args.log_level))
set_verbose_debug(args.verbose) set_verbose_debug(args.verbose)
from lightrag.kg.shared_storage import is_multiprocess
logger.info(f"==== Multi-processor mode: {is_multiprocess} ====")
# Verify that bindings are correctly setup # Verify that bindings are correctly setup
if args.llm_binding not in [ if args.llm_binding not in [
"lollms", "lollms",
@@ -422,11 +418,6 @@ def get_application():
args = types.SimpleNamespace(**json.loads(args_json)) args = types.SimpleNamespace(**json.loads(args_json))
if args.workers > 1:
from lightrag.kg.shared_storage import initialize_share_data
initialize_share_data()
return create_app(args) return create_app(args)
@@ -492,6 +483,9 @@ def main():
display_splash_screen(args) display_splash_screen(args)
from lightrag.kg.shared_storage import initialize_share_data
initialize_share_data(args.workers)
uvicorn_config = { uvicorn_config = {
"app": "lightrag.api.lightrag_server:get_application", "app": "lightrag.api.lightrag_server:get_application",
"factory": True, "factory": True,

View File

@@ -7,35 +7,50 @@ from lightrag.utils import logger
LockType = Union[ProcessLock, ThreadLock] LockType = Union[ProcessLock, ThreadLock]
is_multiprocess = False
_manager = None _manager = None
_global_lock: Optional[LockType] = None _initialized = None
_is_multiprocess = None
is_multiprocess = None
# shared data for storage across processes # shared data for storage across processes
_shared_dicts: Optional[Dict[str, Any]] = {} _shared_dicts: Optional[Dict[str, Any]] = None
_share_objects: Optional[Dict[str, Any]] = {} _share_objects: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized
_global_lock: Optional[LockType] = None
def initialize_share_data():
"""Initialize shared data, only called if multiple processes where workers > 1"""
global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
is_multiprocess = True
logger.info(f"Process {os.getpid()} initializing shared storage") def initialize_share_data(workers: int = 1):
"""Initialize storage data"""
global _manager, _is_multiprocess, is_multiprocess, _global_lock, _shared_dicts, _share_objects, _init_flags, _initialized
# Initialize manager if _initialized and _initialized.value:
if _manager is None: is_multiprocess = _is_multiprocess.value
_manager = Manager() if _is_multiprocess.value:
logger.info(f"Process {os.getpid()} created manager") logger.info(f"Process {os.getpid()} storage data already initialized!")
return
# Create shared dictionaries with manager _manager = Manager()
_shared_dicts = _manager.dict() _initialized = _manager.Value("b", False)
_share_objects = _manager.dict() _is_multiprocess = _manager.Value("b", False)
_init_flags = _manager.dict() # 使用共享字典存储初始化标志
logger.info(f"Process {os.getpid()} created shared dictionaries")
if workers == 1:
_is_multiprocess.value = False
_global_lock = ThreadLock()
_shared_dicts = {}
_share_objects = {}
_init_flags = {}
logger.info(f"Process {os.getpid()} storage data created for Single Process")
else:
_is_multiprocess.value = True
_global_lock = _manager.Lock()
# Create shared dictionaries with manager
_shared_dicts = _manager.dict()
_share_objects = _manager.dict()
_init_flags = _manager.dict() # 使用共享字典存储初始化标志
logger.info(f"Process {os.getpid()} storage data created for Multiple Process")
is_multiprocess = _is_multiprocess.value
def try_initialize_namespace(namespace: str) -> bool: def try_initialize_namespace(namespace: str) -> bool:
""" """
@@ -44,7 +59,7 @@ def try_initialize_namespace(namespace: str) -> bool:
""" """
global _init_flags, _manager global _init_flags, _manager
if is_multiprocess: if _is_multiprocess.value:
if _init_flags is None: if _init_flags is None:
raise RuntimeError( raise RuntimeError(
"Shared storage not initialized. Call initialize_share_data() first." "Shared storage not initialized. Call initialize_share_data() first."
@@ -55,17 +70,13 @@ def try_initialize_namespace(namespace: str) -> bool:
logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}") logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")
# 使用全局锁保护共享字典的访问 with _global_lock:
with _get_global_lock():
# 检查是否已经初始化
if namespace not in _init_flags: if namespace not in _init_flags:
# 设置初始化标志
_init_flags[namespace] = True _init_flags[namespace] = True
logger.info( logger.info(
f"Process {os.getpid()} ready to initialize namespace {namespace}" f"Process {os.getpid()} ready to initialize namespace {namespace}"
) )
return True return True
logger.info( logger.info(
f"Process {os.getpid()} found namespace {namespace} already initialized" f"Process {os.getpid()} found namespace {namespace} already initialized"
) )
@@ -73,14 +84,6 @@ def try_initialize_namespace(namespace: str) -> bool:
def _get_global_lock() -> LockType: def _get_global_lock() -> LockType:
global _global_lock, is_multiprocess, _manager
if _global_lock is None:
if is_multiprocess:
_global_lock = _manager.Lock() # Use manager for lock
else:
_global_lock = ThreadLock()
return _global_lock return _global_lock
@@ -96,36 +99,20 @@ def get_scan_lock() -> LockType:
def get_namespace_object(namespace: str) -> Any: def get_namespace_object(namespace: str) -> Any:
"""Get an object for specific namespace""" """Get an object for specific namespace"""
global _share_objects, is_multiprocess, _manager
if is_multiprocess and not _manager:
raise RuntimeError(
"Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
)
if namespace not in _share_objects: if namespace not in _share_objects:
lock = _get_global_lock() lock = _get_global_lock()
with lock: with lock:
if namespace not in _share_objects: if namespace not in _share_objects:
if is_multiprocess: if _is_multiprocess.value:
_share_objects[namespace] = _manager.Value("O", None) _share_objects[namespace] = _manager.Value("O", None)
else: else:
_share_objects[namespace] = None _share_objects[namespace] = None
return _share_objects[namespace] return _share_objects[namespace]
# 移除不再使用的函数
def get_namespace_data(namespace: str) -> Dict[str, Any]: def get_namespace_data(namespace: str) -> Dict[str, Any]:
"""get storage space for specific storage type(namespace)""" """get storage space for specific storage type(namespace)"""
global _shared_dicts, is_multiprocess, _manager
if is_multiprocess and not _manager:
raise RuntimeError(
"Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
)
if namespace not in _shared_dicts: if namespace not in _shared_dicts:
lock = _get_global_lock() lock = _get_global_lock()

View File

@@ -267,6 +267,9 @@ class LightRAG:
_storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
def __post_init__(self): def __post_init__(self):
from lightrag.kg.shared_storage import initialize_share_data
initialize_share_data()
os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
set_logger(self.log_file_path, self.log_level) set_logger(self.log_file_path, self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}") logger.info(f"Logger initialized for working directory: {self.working_dir}")