From ff9cb2138da353f79d195d389bd57de9a346711e Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 24 Mar 2025 02:11:59 +0800 Subject: [PATCH] Ensure thread safety in storage update callbacks - Added storage lock in index_done_callback - Fixed potential race conditions --- lightrag/api/utils_api.py | 4 ++-- lightrag/kg/faiss_impl.py | 25 +++++++++++++------------ lightrag/kg/nano_vector_db_impl.py | 27 ++++++++++++++------------- lightrag/kg/networkx_impl.py | 25 +++++++++++++------------ 4 files changed, 42 insertions(+), 39 deletions(-) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index ddc0554c..a762b28b 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -365,7 +365,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: args.vector_storage = get_env_value( "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE ) - + # Get MAX_PARALLEL_INSERT from environment args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int) @@ -397,7 +397,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: args.enable_llm_cache_for_extract = get_env_value( "ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool ) - + # Inject LLM temperature configuration args.temperature = get_env_value("TEMPERATURE", 0.5, float) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 57b0cae0..e94ecbe8 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -343,18 +343,19 @@ class FaissVectorDBStorage(BaseVectorStorage): self._id_to_meta = {} async def index_done_callback(self) -> None: - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning( - f"Storage for FAISS {self.namespace} was updated by another process, reloading..." - ) - async with self._storage_lock: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} - self._load_faiss_index() - self.storage_updated.value = False - return False # Return error + async with self._storage_lock: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for FAISS {self.namespace} was updated by another process, reloading..." + ) + async with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + self.storage_updated.value = False + return False # Return error # Acquire lock and perform persistence async with self._storage_lock: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 4f739091..abd1f0ae 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -206,19 +206,20 @@ class NanoVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> bool: """Save data to disk""" - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning( - f"Storage for {self.namespace} was updated by another process, reloading..." - ) - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, - storage_file=self._client_file_name, - ) - # Reset update flag - self.storage_updated.value = False - return False # Return error + async with self._storage_lock: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for {self.namespace} was updated by another process, reloading..." + ) + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, + storage_file=self._client_file_name, + ) + # Reset update flag + self.storage_updated.value = False + return False # Return error # Acquire lock and perform persistence async with self._storage_lock: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 7026cf6d..e21d2ed9 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -401,18 +401,19 @@ class NetworkXStorage(BaseGraphStorage): async def index_done_callback(self) -> bool: """Save data to disk""" - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning( - f"Graph for {self.namespace} was updated by another process, reloading..." - ) - self._graph = ( - NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() - ) - # Reset update flag - self.storage_updated.value = False - return False # Return error + async with self._storage_lock: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Graph for {self.namespace} was updated by another process, reloading..." + ) + self._graph = ( + NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph() + ) + # Reset update flag + self.storage_updated.value = False + return False # Return error # Acquire lock and perform persistence async with self._storage_lock: