Refactor Faiss index access with helper method to improve code organization
This commit is contained in:
@@ -74,6 +74,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
self._id_to_meta.update({})
|
self._id_to_meta.update({})
|
||||||
self._load_faiss_index()
|
self._load_faiss_index()
|
||||||
|
|
||||||
|
def _get_index(self):
|
||||||
|
"""
|
||||||
|
Helper method to get the correct index object based on multiprocess mode.
|
||||||
|
Returns the actual index object that can be used for operations.
|
||||||
|
"""
|
||||||
|
return self._index.value if is_multiprocess else self._index
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert or update vectors in the Faiss index.
|
Insert or update vectors in the Faiss index.
|
||||||
@@ -142,11 +149,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
self._remove_faiss_ids(existing_ids_to_remove)
|
self._remove_faiss_ids(existing_ids_to_remove)
|
||||||
|
|
||||||
# Step 2: Add new vectors
|
# Step 2: Add new vectors
|
||||||
start_idx = (self._index.value if is_multiprocess else self._index).ntotal
|
index = self._get_index()
|
||||||
if is_multiprocess:
|
start_idx = index.ntotal
|
||||||
self._index.value.add(embeddings)
|
index.add(embeddings)
|
||||||
else:
|
|
||||||
self._index.add(embeddings)
|
|
||||||
|
|
||||||
# Step 3: Store metadata + vector for each new ID
|
# Step 3: Store metadata + vector for each new ID
|
||||||
for i, meta in enumerate(list_data):
|
for i, meta in enumerate(list_data):
|
||||||
@@ -173,9 +178,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
# Perform the similarity search
|
# Perform the similarity search
|
||||||
with self._storage_lock:
|
with self._storage_lock:
|
||||||
distances, indices = (
|
distances, indices = self._get_index().search(embedding, top_k)
|
||||||
self._index.value if is_multiprocess else self._index
|
|
||||||
).search(embedding, top_k)
|
|
||||||
|
|
||||||
distances = distances[0]
|
distances = distances[0]
|
||||||
indices = indices[0]
|
indices = indices[0]
|
||||||
@@ -303,7 +306,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
"""
|
"""
|
||||||
with self._storage_lock:
|
with self._storage_lock:
|
||||||
faiss.write_index(
|
faiss.write_index(
|
||||||
self._index.value if is_multiprocess else self._index,
|
self._get_index(),
|
||||||
self._faiss_index_file,
|
self._faiss_index_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user