fix: Improve async handling and FAISS storage reliability

- Add async context manager support
- Fix embedding data type conversion
- Improve error handling in FAISS ops
- Add multiprocess storage sync
This commit is contained in:
yangdx
2025-03-01 23:35:09 +08:00
parent 9aef112d51
commit e8d0d065f3
2 changed files with 38 additions and 38 deletions

View File

@@ -71,7 +71,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
async def _get_index(self): async def _get_index(self):
"""Check if the shtorage should be reloaded""" """Check if the shtorage should be reloaded"""
# Acquire lock to prevent concurrent read and write # Acquire lock to prevent concurrent read and write
with self._storage_lock: async with self._storage_lock:
# Check if storage was updated by another process # Check if storage was updated by another process
if (is_multiprocess and self.storage_updated.value) or ( if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated not is_multiprocess and self.storage_updated
@@ -139,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
return [] return []
# Normalize embeddings for cosine similarity (in-place) # Convert to float32 and normalize embeddings for cosine similarity (in-place)
embeddings = embeddings.astype(np.float32)
faiss.normalize_L2(embeddings) faiss.normalize_L2(embeddings)
# Upsert logic: # Upsert logic:
@@ -153,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
existing_ids_to_remove.append(faiss_internal_id) existing_ids_to_remove.append(faiss_internal_id)
if existing_ids_to_remove: if existing_ids_to_remove:
self._remove_faiss_ids(existing_ids_to_remove) await self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors # Step 2: Add new vectors
index = await self._get_index() index = await self._get_index()
@@ -185,7 +186,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Perform the similarity search # Perform the similarity search
index = await self._get_index() index = await self._get_index()
distances, indices = index().search(embedding, top_k) distances, indices = index.search(embedding, top_k)
distances = distances[0] distances = distances[0]
indices = indices[0] indices = indices[0]
@@ -229,7 +230,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
to_remove.append(fid) to_remove.append(fid)
if to_remove: if to_remove:
self._remove_faiss_ids(to_remove) await self._remove_faiss_ids(to_remove)
logger.debug( logger.debug(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
) )
@@ -251,7 +252,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Found {len(relations)} relations for {entity_name}") logger.debug(f"Found {len(relations)} relations for {entity_name}")
if relations: if relations:
self._remove_faiss_ids(relations) await self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}") logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
# -------------------------------------------------------------------------------- # --------------------------------------------------------------------------------
@@ -267,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
return fid return fid
return None return None
def _remove_faiss_ids(self, fid_list): async def _remove_faiss_ids(self, fid_list):
""" """
Remove a list of internal Faiss IDs from the index. Remove a list of internal Faiss IDs from the index.
Because IndexFlatIP doesn't support 'removals', Because IndexFlatIP doesn't support 'removals',
@@ -283,7 +284,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
new_id_to_meta[new_fid] = vec_meta new_id_to_meta[new_fid] = vec_meta
with self._storage_lock: async with self._storage_lock:
# Re-init index # Re-init index
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep: if vectors_to_keep:
@@ -339,7 +340,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {} self._id_to_meta = {}
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# Check if storage was updated by another process # Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value: if is_multiprocess and self.storage_updated.value:
@@ -347,7 +347,7 @@ async def index_done_callback(self) -> None:
logger.warning( logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..." f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
) )
with self._storage_lock: async with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {} self._id_to_meta = {}
self._load_faiss_index() self._load_faiss_index()