feat: add multi-process support for FAISS vector storage

• Add storage update flag and locks
• Support cross-process index reload
• Add async initialize method
This commit is contained in:
yangdx
2025-03-01 12:42:30 +08:00
parent d4f6dcfd54
commit 35bcfca28f

View File

@@ -15,7 +15,12 @@ if not pm.is_installed("faiss"):
pm.install("faiss") pm.install("faiss")
import faiss # type: ignore import faiss # type: ignore
from threading import Lock as ThreadLock from .shared_storage import (
get_storage_lock,
get_update_flag,
set_all_update_flags,
is_multiprocess,
)
@final @final
@@ -45,29 +50,43 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# Embedding dimension (e.g. 768) must match your embedding function # Embedding dimension (e.g. 768) must match your embedding function
self._dim = self.embedding_func.embedding_dim self._dim = self.embedding_func.embedding_dim
self._storage_lock = ThreadLock()
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# If you have a large number of vectors, you might want IVF or other indexes. # If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP. # For demonstration, we use a simple IndexFlatIP.
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc. # Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID). # Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta = {} self._id_to_meta = {}
# Attempt to load an existing index + metadata from disk self._load_faiss_index()
with self._storage_lock:
self._load_faiss_index()
def _get_index(self): async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
async def _get_index(self):
"""Check if the shtorage should be reloaded""" """Check if the shtorage should be reloaded"""
# Acquire lock to prevent concurrent read and write
with self._storage_lock:
# Check if storage was updated by another process
if (is_multiprocess and self.storage_updated.value) or \
(not is_multiprocess and self.storage_updated):
logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
# Reload data
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
return self._index return self._index
async def index_done_callback(self) -> None:
with self._storage_lock:
self._save_faiss_index()
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """
Insert or update vectors in the Faiss index. Insert or update vectors in the Faiss index.
@@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._remove_faiss_ids(existing_ids_to_remove) self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors # Step 2: Add new vectors
index = self._get_index() index = await self._get_index()
start_idx = index.ntotal start_idx = index.ntotal
index.add(embeddings) index.add(embeddings)
@@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
) )
# Perform the similarity search # Perform the similarity search
distances, indices = self._get_index().search(embedding, top_k) index = await self._get_index()
distances, indices = index().search(embedding, top_k)
distances = distances[0] distances = distances[0]
indices = indices[0] indices = indices[0]
@@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.warning("Starting with an empty Faiss index.") logger.warning("Starting with an empty Faiss index.")
self._index = faiss.IndexFlatIP(self._dim) self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {} self._id_to_meta = {}
async def index_done_callback(self) -> None:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()
self.storage_updated.value = False
return False # Return error
# Acquire lock and perform persistence
async with self._storage_lock:
try:
# Save data to disk
self._save_faiss_index()
# Set all update flags to False
await set_all_update_flags(self.namespace)
# Reset own update flag to avoid self-reloading
if is_multiprocess:
self.storage_updated.value = False
else:
self.storage_updated = False
except Exception as e:
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
return False # Return error
return True # Return success