From a721421bd8e8a59e1f653695d67748e5999abacd Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 01:49:26 +0800 Subject: [PATCH] Add async support and update flag mechanism for shared storage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Use asyncio.Lock instead of thread lock for single process mode • Add storage update notification system --- lightrag/kg/shared_storage.py | 86 +++++++++++++++++++++++++++++------ 1 file changed, 73 insertions(+), 13 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 4cad25fa..7ac0d625 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,7 +1,7 @@ import os import sys +import asyncio from multiprocessing.synchronize import Lock as ProcessLock -from threading import Lock as ThreadLock from multiprocessing import Manager from typing import Any, Dict, Optional, Union @@ -15,16 +15,18 @@ def direct_log(message, level="INFO"): print(f"{level}: {message}", file=sys.stderr, flush=True) -LockType = Union[ProcessLock, ThreadLock] +LockType = Union[ProcessLock, asyncio.Lock] +is_multiprocess = None +_workers = None _manager = None _initialized = None -is_multiprocess = None _global_lock: Optional[LockType] = None # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized +_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated def initialize_share_data(workers: int = 1): @@ -47,12 +49,14 @@ def initialize_share_data(workers: int = 1): """ global \ _manager, \ + _workers, \ is_multiprocess, \ is_multiprocess, \ _global_lock, \ _shared_dicts, \ _init_flags, \ - _initialized + _initialized, \ + _update_flags # Check if already initialized if _initialized: @@ -62,20 +66,23 @@ def initialize_share_data(workers: int = 1): return _manager = Manager() + _workers = workers if workers > 1: is_multiprocess = True _global_lock = _manager.Lock() _shared_dicts = _manager.dict() _init_flags = _manager.dict() + _update_flags = _manager.dict() direct_log( f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" ) else: is_multiprocess = False - _global_lock = ThreadLock() + _global_lock = asyncio.Lock() _shared_dicts = {} _init_flags = {} + _update_flags = {} direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") # Mark as initialized @@ -86,7 +93,6 @@ def initialize_share_data(workers: int = 1): # Create a shared list object for history_messages history_messages = _manager.list() if is_multiprocess else [] - pipeline_namespace.update( { "busy": False, # Control concurrent processes @@ -102,6 +108,58 @@ def initialize_share_data(workers: int = 1): ) +async def get_update_flags(namespace: str): + """ + Create a updated flags of a specific namespace. + Caller must store the dict object locally and use it to determine when to update the storage. + """ + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + if is_multiprocess: + with _global_lock: + if namespace not in _update_flags: + if _manager is not None: + _update_flags[namespace] = _manager.list() + direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") + + if _manager is not None: + new_update_flag = _manager.Value('b', False) + _update_flags[namespace].append(new_update_flag) + return new_update_flag + else: + async with _global_lock: + if namespace not in _update_flags: + _update_flags[namespace] = [] + direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]") + + new_update_flag = False + _update_flags[namespace].append(new_update_flag) + return new_update_flag + +async def set_update_flag(namespace: str): + """Set all update flag of namespace to indicate storage needs updating""" + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + if is_multiprocess: + with _global_lock: + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not found in update flags") + # Update flags for multiprocess mode + for i in range(len(_update_flags[namespace])): + _update_flags[namespace][i].value = True + else: + async with _global_lock: + if namespace not in _update_flags: + raise ValueError(f"Namespace {namespace} not found in update flags") + # Update flags for single process mode + for i in range(len(_update_flags[namespace])): + _update_flags[namespace][i] = True + + def try_initialize_namespace(namespace: str) -> bool: """ Try to initialize a namespace. Returns True if the current process gets initialization permission. @@ -129,7 +187,7 @@ def get_storage_lock() -> LockType: return _global_lock -def get_namespace_data(namespace: str) -> Dict[str, Any]: +async def get_namespace_data(namespace: str) -> Dict[str, Any]: """get storage space for specific storage type(namespace)""" if _shared_dicts is None: direct_log( @@ -138,12 +196,14 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]: ) raise ValueError("Shared dictionaries not initialized") - lock = get_storage_lock() - with lock: - if namespace not in _shared_dicts: - if is_multiprocess and _manager is not None: - _shared_dicts[namespace] = _manager.dict() - else: + if is_multiprocess: + with _global_lock: + if namespace not in _shared_dicts: + if _manager is not None: + _shared_dicts[namespace] = _manager.dict() + else: + async with _global_lock: + if namespace not in _shared_dicts: _shared_dicts[namespace] = {} return _shared_dicts[namespace]