feat: add multi-process support for FAISS vector storage
• Add storage update flag and locks • Support cross-process index reload • Add async initialize method
This commit is contained in:
@@ -15,7 +15,12 @@ if not pm.is_installed("faiss"):
|
||||
pm.install("faiss")
|
||||
|
||||
import faiss # type: ignore
|
||||
from threading import Lock as ThreadLock
|
||||
from .shared_storage import (
|
||||
get_storage_lock,
|
||||
get_update_flag,
|
||||
set_all_update_flags,
|
||||
is_multiprocess,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@@ -45,28 +50,42 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Embedding dimension (e.g. 768) must match your embedding function
|
||||
self._dim = self.embedding_func.embedding_dim
|
||||
self._storage_lock = ThreadLock()
|
||||
|
||||
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
||||
# For demonstration, we use a simple IndexFlatIP.
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
|
||||
# Keep a local store for metadata, IDs, etc.
|
||||
# Maps <int faiss_id> → metadata (including your original ID).
|
||||
self._id_to_meta = {}
|
||||
|
||||
# Attempt to load an existing index + metadata from disk
|
||||
with self._storage_lock:
|
||||
self._load_faiss_index()
|
||||
self._load_faiss_index()
|
||||
|
||||
def _get_index(self):
|
||||
async def initialize(self):
|
||||
"""Initialize storage data"""
|
||||
# Get the update flag for cross-process update notification
|
||||
self.storage_updated = await get_update_flag(self.namespace)
|
||||
# Get the storage lock for use in other methods
|
||||
self._storage_lock = get_storage_lock()
|
||||
|
||||
async def _get_index(self):
|
||||
"""Check if the shtorage should be reloaded"""
|
||||
# Acquire lock to prevent concurrent read and write
|
||||
with self._storage_lock:
|
||||
# Check if storage was updated by another process
|
||||
if (is_multiprocess and self.storage_updated.value) or \
|
||||
(not is_multiprocess and self.storage_updated):
|
||||
logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
|
||||
# Reload data
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
return self._index
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
with self._storage_lock:
|
||||
self._save_faiss_index()
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
@@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._remove_faiss_ids(existing_ids_to_remove)
|
||||
|
||||
# Step 2: Add new vectors
|
||||
index = self._get_index()
|
||||
index = await self._get_index()
|
||||
start_idx = index.ntotal
|
||||
index.add(embeddings)
|
||||
|
||||
@@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
# Perform the similarity search
|
||||
distances, indices = self._get_index().search(embedding, top_k)
|
||||
index = await self._get_index()
|
||||
distances, indices = index().search(embedding, top_k)
|
||||
|
||||
distances = distances[0]
|
||||
indices = indices[0]
|
||||
@@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
logger.warning("Starting with an empty Faiss index.")
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# Check if storage was updated by another process
|
||||
if is_multiprocess and self.storage_updated.value:
|
||||
# Storage was updated by another process, reload data instead of saving
|
||||
logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
|
||||
with self._storage_lock:
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
self._id_to_meta = {}
|
||||
self._load_faiss_index()
|
||||
self.storage_updated.value = False
|
||||
return False # Return error
|
||||
|
||||
# Acquire lock and perform persistence
|
||||
async with self._storage_lock:
|
||||
try:
|
||||
# Save data to disk
|
||||
self._save_faiss_index()
|
||||
# Set all update flags to False
|
||||
await set_all_update_flags(self.namespace)
|
||||
# Reset own update flag to avoid self-reloading
|
||||
if is_multiprocess:
|
||||
self.storage_updated.value = False
|
||||
else:
|
||||
self.storage_updated = False
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
|
||||
return False # Return error
|
||||
|
||||
return True # Return success
|
||||
|
Reference in New Issue
Block a user