From 15e060f8544ec3e31dd150f7bd96eb271a8f136c Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 25 Mar 2025 10:48:15 +0800
Subject: [PATCH] Fix shared storage update status handling for in-memory
 storage

---
 lightrag/kg/faiss_impl.py          | 17 ++-------
 lightrag/kg/nano_vector_db_impl.py | 19 +++-------
 lightrag/kg/networkx_impl.py       | 17 ++-------
 lightrag/kg/shared_storage.py      | 60 +++++++++++++-----------------
 4 files changed, 39 insertions(+), 74 deletions(-)

diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
index e94ecbe8..b8176037 100644
--- a/lightrag/kg/faiss_impl.py
+++ b/lightrag/kg/faiss_impl.py
@@ -19,7 +19,6 @@ from .shared_storage import (
     get_storage_lock,
     get_update_flag,
     set_all_update_flags,
-    is_multiprocess,
 )
 
 
@@ -73,9 +72,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if storage was updated by another process
-            if (is_multiprocess and self.storage_updated.value) or (
-                not is_multiprocess and self.storage_updated
-            ):
+            if self.storage_updated.value:
                 logger.info(
                     f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
                 )
@@ -83,10 +80,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 self._index = faiss.IndexFlatIP(self._dim)
                 self._id_to_meta = {}
                 self._load_faiss_index()
-                if is_multiprocess:
-                    self.storage_updated.value = False
-                else:
-                    self.storage_updated = False
+                self.storage_updated.value = False
             return self._index
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -345,7 +339,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
             # Check if storage was updated by another process
-            if is_multiprocess and self.storage_updated.value:
+            if self.storage_updated.value:
                 # Storage was updated by another process, reload data instead of saving
                 logger.warning(
                     f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
@@ -365,10 +359,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 # Notify other processes that data has been updated
                 await set_all_update_flags(self.namespace)
                 # Reset own update flag to avoid self-reloading
-                if is_multiprocess:
-                    self.storage_updated.value = False
-                else:
-                    self.storage_updated = False
+                self.storage_updated.value = False
             except Exception as e:
                 logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
                 return False  # Return error
diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py
index abd1f0ae..553ba0b2 100644
--- a/lightrag/kg/nano_vector_db_impl.py
+++ b/lightrag/kg/nano_vector_db_impl.py
@@ -20,7 +20,6 @@ from .shared_storage import (
     get_storage_lock,
     get_update_flag,
     set_all_update_flags,
-    is_multiprocess,
 )
 
 
@@ -57,16 +56,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Get the update flag for cross-process update notification
         self.storage_updated = await get_update_flag(self.namespace)
         # Get the storage lock for use in other methods
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = get_storage_lock(enable_logging=False)
 
     async def _get_client(self):
         """Check if the storage should be reloaded"""
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if data needs to be reloaded
-            if (is_multiprocess and self.storage_updated.value) or (
-                not is_multiprocess and self.storage_updated
-            ):
+            if self.storage_updated.value:
                 logger.info(
                     f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
                 )
@@ -76,10 +73,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                     storage_file=self._client_file_name,
                 )
                 # Reset update flag
-                if is_multiprocess:
-                    self.storage_updated.value = False
-                else:
-                    self.storage_updated = False
+                self.storage_updated.value = False
 
         return self._client
 
@@ -208,7 +202,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         """Save data to disk"""
         async with self._storage_lock:
             # Check if storage was updated by another process
-            if is_multiprocess and self.storage_updated.value:
+            if self.storage_updated.value:
                 # Storage was updated by another process, reload data instead of saving
                 logger.warning(
                     f"Storage for {self.namespace} was updated by another process, reloading..."
@@ -229,10 +223,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 # Notify other processes that data has been updated
                 await set_all_update_flags(self.namespace)
                 # Reset own update flag to avoid self-reloading
-                if is_multiprocess:
-                    self.storage_updated.value = False
-                else:
-                    self.storage_updated = False
+                self.storage_updated.value = False
                 return True  # Return success
             except Exception as e:
                 logger.error(f"Error saving data for {self.namespace}: {e}")
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index e21d2ed9..324fe7af 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -21,7 +21,6 @@ from .shared_storage import (
     get_storage_lock,
     get_update_flag,
     set_all_update_flags,
-    is_multiprocess,
 )
 
 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@@ -110,9 +109,7 @@ class NetworkXStorage(BaseGraphStorage):
         # Acquire lock to prevent concurrent read and write
         async with self._storage_lock:
             # Check if data needs to be reloaded
-            if (is_multiprocess and self.storage_updated.value) or (
-                not is_multiprocess and self.storage_updated
-            ):
+            if self.storage_updated.value:
                 logger.info(
                     f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
                 )
@@ -121,10 +118,7 @@ class NetworkXStorage(BaseGraphStorage):
                     NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
                 )
                 # Reset update flag
-                if is_multiprocess:
-                    self.storage_updated.value = False
-                else:
-                    self.storage_updated = False
+                self.storage_updated.value = False
 
         return self._graph
 
@@ -403,7 +397,7 @@ class NetworkXStorage(BaseGraphStorage):
         """Save data to disk"""
        async with self._storage_lock:
             # Check if storage was updated by another process
-            if is_multiprocess and self.storage_updated.value:
+            if self.storage_updated.value:
                 # Storage was updated by another process, reload data instead of saving
                 logger.warning(
                     f"Graph for {self.namespace} was updated by another process, reloading..."
@@ -423,10 +417,7 @@ class NetworkXStorage(BaseGraphStorage):
                 # Notify other processes that data has been updated
                 await set_all_update_flags(self.namespace)
                 # Reset own update flag to avoid self-reloading
-                if is_multiprocess:
-                    self.storage_updated.value = False
-                else:
-                    self.storage_updated = False
+                self.storage_updated.value = False
                 return True  # Return success
             except Exception as e:
                 logger.error(f"Error saving graph for {self.namespace}: {e}")
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index e26645c8..4bdbce99 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -24,7 +24,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
 T = TypeVar("T")
 LockType = Union[ProcessLock, asyncio.Lock]
 
-is_multiprocess = None
+_is_multiprocess = None
 _workers = None
 _manager = None
 _initialized = None
@@ -218,10 +218,10 @@ class UnifiedLock(Generic[T]):
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("internal_lock") if is_multiprocess else None
+    async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_internal_lock,
-        is_async=not is_multiprocess,
+        is_async=not _is_multiprocess,
         name="internal_lock",
         enable_logging=enable_logging,
         async_lock=async_lock,
     )
@@ -230,10 +230,10 @@
 def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("storage_lock") if is_multiprocess else None
+    async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_storage_lock,
-        is_async=not is_multiprocess,
+        is_async=not _is_multiprocess,
         name="storage_lock",
         enable_logging=enable_logging,
         async_lock=async_lock,
     )
@@ -242,10 +242,10 @@
 def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("pipeline_status_lock") if is_multiprocess else None
+    async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_pipeline_status_lock,
-        is_async=not is_multiprocess,
+        is_async=not _is_multiprocess,
         name="pipeline_status_lock",
         enable_logging=enable_logging,
         async_lock=async_lock,
     )
@@ -254,10 +254,10 @@
 def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified graph database lock for ensuring atomic operations"""
-    async_lock = _async_locks.get("graph_db_lock") if is_multiprocess else None
+    async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_graph_db_lock,
-        is_async=not is_multiprocess,
+        is_async=not _is_multiprocess,
         name="graph_db_lock",
         enable_logging=enable_logging,
         async_lock=async_lock,
     )
@@ -266,10 +266,10 @@
 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified data initialization lock for ensuring atomic data initialization"""
-    async_lock = _async_locks.get("data_init_lock") if is_multiprocess else None
+    async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_data_init_lock,
-        is_async=not is_multiprocess,
+        is_async=not _is_multiprocess,
         name="data_init_lock",
         enable_logging=enable_logging,
         async_lock=async_lock,
     )
@@ -297,7 +297,7 @@ def initialize_share_data(workers: int = 1):
     global \
         _manager, \
         _workers, \
-        is_multiprocess, \
+        _is_multiprocess, \
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
@@ -312,14 +312,14 @@ def initialize_share_data(workers: int = 1):
     # Check if already initialized
     if _initialized:
         direct_log(
-            f"Process {os.getpid()} Shared-Data already initialized (multiprocess={is_multiprocess})"
+            f"Process {os.getpid()} Shared-Data already initialized (multiprocess={_is_multiprocess})"
         )
         return
 
     _workers = workers
 
     if workers > 1:
-        is_multiprocess = True
+        _is_multiprocess = True
         _manager = Manager()
         _internal_lock = _manager.Lock()
         _storage_lock = _manager.Lock()
@@ -343,7 +343,7 @@ def initialize_share_data(workers: int = 1):
             f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
         )
     else:
-        is_multiprocess = False
+        _is_multiprocess = False
         _internal_lock = asyncio.Lock()
         _storage_lock = asyncio.Lock()
         _pipeline_status_lock = asyncio.Lock()
@@ -372,7 +372,7 @@ async def initialize_pipeline_status():
         return
 
     # Create a shared list object for history_messages
-    history_messages = _manager.list() if is_multiprocess else []
+    history_messages = _manager.list() if _is_multiprocess else []
     pipeline_namespace.update(
         {
             "autoscanned": False,  # Auto-scan started
@@ -401,7 +401,7 @@ async def get_update_flag(namespace: str):
 
     async with get_internal_lock():
         if namespace not in _update_flags:
-            if is_multiprocess and _manager is not None:
+            if _is_multiprocess and _manager is not None:
                 _update_flags[namespace] = _manager.list()
             else:
                 _update_flags[namespace] = []
@@ -409,7 +409,7 @@
             direct_log(
                 f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
             )
-        if is_multiprocess and _manager is not None:
+        if _is_multiprocess and _manager is not None:
             new_update_flag = _manager.Value("b", False)
         else:
             # Create a simple mutable object to store boolean value for compatibility with multiprocess
@@ -434,11 +434,7 @@ async def set_all_update_flags(namespace: str):
         raise ValueError(f"Namespace {namespace} not found in update flags")
     # Update flags for both modes
     for i in range(len(_update_flags[namespace])):
-        if is_multiprocess:
-            _update_flags[namespace][i].value = True
-        else:
-            # Use .value attribute instead of direct assignment
-            _update_flags[namespace][i].value = True
+        _update_flags[namespace][i].value = True
 
 
 async def clear_all_update_flags(namespace: str):
@@ -452,11 +448,7 @@
         raise ValueError(f"Namespace {namespace} not found in update flags")
     # Update flags for both modes
     for i in range(len(_update_flags[namespace])):
-        if is_multiprocess:
-            _update_flags[namespace][i].value = False
-        else:
-            # Use .value attribute instead of direct assignment
-            _update_flags[namespace][i].value = False
+        _update_flags[namespace][i].value = False
 
 
 async def get_all_update_flags_status() -> Dict[str, list]:
@@ -474,7 +466,7 @@
     for namespace, flags in _update_flags.items():
         worker_statuses = []
         for flag in flags:
-            if is_multiprocess:
+            if _is_multiprocess:
                 worker_statuses.append(flag.value)
             else:
                 worker_statuses.append(flag)
@@ -518,7 +510,7 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
 
     async with get_internal_lock():
         if namespace not in _shared_dicts:
-            if is_multiprocess and _manager is not None:
+            if _is_multiprocess and _manager is not None:
                 _shared_dicts[namespace] = _manager.dict()
             else:
                 _shared_dicts[namespace] = {}
@@ -538,7 +530,7 @@ def finalize_share_data():
     """
     global \
         _manager, \
-        is_multiprocess, \
+        _is_multiprocess, \
         _storage_lock, \
         _internal_lock, \
         _pipeline_status_lock, \
@@ -558,11 +550,11 @@
         return
 
     direct_log(
-        f"Process {os.getpid()} finalizing storage data (multiprocess={is_multiprocess})"
+        f"Process {os.getpid()} finalizing storage data (multiprocess={_is_multiprocess})"
     )
 
     # In multi-process mode, shut down the Manager
-    if is_multiprocess and _manager is not None:
+    if _is_multiprocess and _manager is not None:
         try:
             # Clear shared resources before shutting down Manager
             if _shared_dicts is not None:
@@ -604,7 +596,7 @@
     # Reset global variables
     _manager = None
     _initialized = None
-    is_multiprocess = None
+    _is_multiprocess = None
     _shared_dicts = None
     _init_flags = None
     _storage_lock = None
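
Note on the pattern (illustrative, not part of the patch): the commit works because both flag variants expose the same `.value` attribute. In multi-worker mode the flag is a `multiprocessing.Manager().Value("b", False)` proxy; in single-process mode the patch substitutes a simple mutable object that mimics that interface, so call sites such as `_get_client()` and `index_done_callback()` read and reset `flag.value` without branching on the process mode. A minimal sketch of the idea follows; the `MutableBoolean` and `make_update_flag` names are hypothetical stand-ins, not identifiers from the LightRAG codebase.

from multiprocessing import Manager


class MutableBoolean:
    """Single-process stand-in mimicking the .value interface of Manager().Value("b", ...)."""

    def __init__(self, initial: bool = False):
        self.value = initial


def make_update_flag(is_multiprocess: bool, manager=None):
    """Return an update flag with a uniform .value interface for both modes."""
    if is_multiprocess and manager is not None:
        # Shared across worker processes via the Manager proxy
        return manager.Value("b", False)
    # Local to this process, but read and written exactly the same way
    return MutableBoolean(False)


if __name__ == "__main__":
    # Single-process mode: callers touch flag.value unconditionally.
    flag = make_update_flag(is_multiprocess=False)
    flag.value = True
    assert flag.value

    # Multi-process mode: same call sites, state shared via the Manager.
    with Manager() as manager:
        shared = make_update_flag(is_multiprocess=True, manager=manager)
        shared.value = True
        assert shared.value

Either way, the storage classes' check-and-reset logic (`if self.storage_updated.value: ... self.storage_updated.value = False`) stays a single code path, which is what lets this patch delete every `is_multiprocess` branch in the three in-memory backends.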