Refactor Faiss index access with helper method to improve code organization

This commit is contained in:
yangdx
2025-02-27 15:09:19 +08:00
parent f007ebf006
commit 438e4780a8

View File

@@ -74,6 +74,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._id_to_meta.update({}) self._id_to_meta.update({})
self._load_faiss_index() self._load_faiss_index()
def _get_index(self):
"""
Helper method to get the correct index object based on multiprocess mode.
Returns the actual index object that can be used for operations.
"""
return self._index.value if is_multiprocess else self._index
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """
Insert or update vectors in the Faiss index. Insert or update vectors in the Faiss index.
@@ -142,11 +149,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._remove_faiss_ids(existing_ids_to_remove) self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors # Step 2: Add new vectors
start_idx = (self._index.value if is_multiprocess else self._index).ntotal index = self._get_index()
if is_multiprocess: start_idx = index.ntotal
self._index.value.add(embeddings) index.add(embeddings)
else:
self._index.add(embeddings)
# Step 3: Store metadata + vector for each new ID # Step 3: Store metadata + vector for each new ID
for i, meta in enumerate(list_data): for i, meta in enumerate(list_data):
@@ -173,9 +178,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Perform the similarity search # Perform the similarity search
with self._storage_lock: with self._storage_lock:
distances, indices = ( distances, indices = self._get_index().search(embedding, top_k)
self._index.value if is_multiprocess else self._index
).search(embedding, top_k)
distances = distances[0] distances = distances[0]
indices = indices[0] indices = indices[0]
@@ -303,7 +306,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
""" """
with self._storage_lock: with self._storage_lock:
faiss.write_index( faiss.write_index(
self._index.value if is_multiprocess else self._index, self._get_index(),
self._faiss_index_file, self._faiss_index_file,
) )