Ensure thread safety in storage update callbacks
- Added storage lock in index_done_callback - Fixed potential race conditions
This commit is contained in:
@@ -343,18 +343,19 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
self._id_to_meta = {}
|
self._id_to_meta = {}
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
# Check if storage was updated by another process
|
async with self._storage_lock:
|
||||||
if is_multiprocess and self.storage_updated.value:
|
# Check if storage was updated by another process
|
||||||
# Storage was updated by another process, reload data instead of saving
|
if is_multiprocess and self.storage_updated.value:
|
||||||
logger.warning(
|
# Storage was updated by another process, reload data instead of saving
|
||||||
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
|
logger.warning(
|
||||||
)
|
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
|
||||||
async with self._storage_lock:
|
)
|
||||||
self._index = faiss.IndexFlatIP(self._dim)
|
async with self._storage_lock:
|
||||||
self._id_to_meta = {}
|
self._index = faiss.IndexFlatIP(self._dim)
|
||||||
self._load_faiss_index()
|
self._id_to_meta = {}
|
||||||
self.storage_updated.value = False
|
self._load_faiss_index()
|
||||||
return False # Return error
|
self.storage_updated.value = False
|
||||||
|
return False # Return error
|
||||||
|
|
||||||
# Acquire lock and perform persistence
|
# Acquire lock and perform persistence
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
@@ -206,19 +206,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
async def index_done_callback(self) -> bool:
|
async def index_done_callback(self) -> bool:
|
||||||
"""Save data to disk"""
|
"""Save data to disk"""
|
||||||
# Check if storage was updated by another process
|
async with self._storage_lock:
|
||||||
if is_multiprocess and self.storage_updated.value:
|
# Check if storage was updated by another process
|
||||||
# Storage was updated by another process, reload data instead of saving
|
if is_multiprocess and self.storage_updated.value:
|
||||||
logger.warning(
|
# Storage was updated by another process, reload data instead of saving
|
||||||
f"Storage for {self.namespace} was updated by another process, reloading..."
|
logger.warning(
|
||||||
)
|
f"Storage for {self.namespace} was updated by another process, reloading..."
|
||||||
self._client = NanoVectorDB(
|
)
|
||||||
self.embedding_func.embedding_dim,
|
self._client = NanoVectorDB(
|
||||||
storage_file=self._client_file_name,
|
self.embedding_func.embedding_dim,
|
||||||
)
|
storage_file=self._client_file_name,
|
||||||
# Reset update flag
|
)
|
||||||
self.storage_updated.value = False
|
# Reset update flag
|
||||||
return False # Return error
|
self.storage_updated.value = False
|
||||||
|
return False # Return error
|
||||||
|
|
||||||
# Acquire lock and perform persistence
|
# Acquire lock and perform persistence
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
@@ -401,18 +401,19 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def index_done_callback(self) -> bool:
|
async def index_done_callback(self) -> bool:
|
||||||
"""Save data to disk"""
|
"""Save data to disk"""
|
||||||
# Check if storage was updated by another process
|
async with self._storage_lock:
|
||||||
if is_multiprocess and self.storage_updated.value:
|
# Check if storage was updated by another process
|
||||||
# Storage was updated by another process, reload data instead of saving
|
if is_multiprocess and self.storage_updated.value:
|
||||||
logger.warning(
|
# Storage was updated by another process, reload data instead of saving
|
||||||
f"Graph for {self.namespace} was updated by another process, reloading..."
|
logger.warning(
|
||||||
)
|
f"Graph for {self.namespace} was updated by another process, reloading..."
|
||||||
self._graph = (
|
)
|
||||||
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
self._graph = (
|
||||||
)
|
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
|
||||||
# Reset update flag
|
)
|
||||||
self.storage_updated.value = False
|
# Reset update flag
|
||||||
return False # Return error
|
self.storage_updated.value = False
|
||||||
|
return False # Return error
|
||||||
|
|
||||||
# Acquire lock and perform persistence
|
# Acquire lock and perform persistence
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
Reference in New Issue
Block a user