Ensure thread safety in storage update callbacks

- Added storage lock in index_done_callback
- Fixed potential race conditions
This commit is contained in:
yangdx
2025-03-24 02:11:59 +08:00
parent 7e8a2c0e9b
commit ff9cb2138d
4 changed files with 42 additions and 39 deletions

View File

@@ -365,7 +365,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
args.vector_storage = get_env_value(
"LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE
)
# Get MAX_PARALLEL_INSERT from environment
args.max_parallel_insert = get_env_value("MAX_PARALLEL_INSERT", 2, int)
@@ -397,7 +397,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
args.enable_llm_cache_for_extract = get_env_value(
"ENABLE_LLM_CACHE_FOR_EXTRACT", True, bool
)
# Inject LLM temperature configuration
args.temperature = get_env_value("TEMPERATURE", 0.5, float)

View File

@@ -343,18 +343,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._id_to_meta = {}
async def index_done_callback(self) -> None:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
)
async with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
self.storage_updated.value = False
return False # Return error
async with self._storage_lock:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
)
async with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:

View File

@@ -206,19 +206,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self) -> bool:
"""Save data to disk"""
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for {self.namespace} was updated by another process, reloading..."
)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
async with self._storage_lock:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for {self.namespace} was updated by another process, reloading..."
)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
storage_file=self._client_file_name,
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:

View File

@@ -401,18 +401,19 @@ class NetworkXStorage(BaseGraphStorage):
async def index_done_callback(self) -> bool:
"""Save data to disk"""
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Graph for {self.namespace} was updated by another process, reloading..."
)
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
async with self._storage_lock:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Graph for {self.namespace} was updated by another process, reloading..."
)
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
)
# Reset update flag
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock: