feat: add multi-process support for FAISS vector storage
• Add storage update flag and locks
• Support cross-process index reload
• Add async initialize method
This commit is contained in:
@@ -15,7 +15,12 @@ if not pm.is_installed("faiss"):
|
|||||||
pm.install("faiss")
|
pm.install("faiss")
|
||||||
|
|
||||||
import faiss # type: ignore
|
import faiss # type: ignore
|
||||||
from threading import Lock as ThreadLock
|
from .shared_storage import (
|
||||||
|
get_storage_lock,
|
||||||
|
get_update_flag,
|
||||||
|
set_all_update_flags,
|
||||||
|
is_multiprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@@ -45,29 +50,43 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
# Embedding dimension (e.g. 768) must match your embedding function
|
# Embedding dimension (e.g. 768) must match your embedding function
|
||||||
self._dim = self.embedding_func.embedding_dim
|
self._dim = self.embedding_func.embedding_dim
|
||||||
self._storage_lock = ThreadLock()
|
|
||||||
|
|
||||||
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
||||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
# If you have a large number of vectors, you might want IVF or other indexes.
|
||||||
# For demonstration, we use a simple IndexFlatIP.
|
# For demonstration, we use a simple IndexFlatIP.
|
||||||
self._index = faiss.IndexFlatIP(self._dim)
|
self._index = faiss.IndexFlatIP(self._dim)
|
||||||
|
|
||||||
# Keep a local store for metadata, IDs, etc.
|
# Keep a local store for metadata, IDs, etc.
|
||||||
# Maps <int faiss_id> → metadata (including your original ID).
|
# Maps <int faiss_id> → metadata (including your original ID).
|
||||||
self._id_to_meta = {}
|
self._id_to_meta = {}
|
||||||
|
|
||||||
# Attempt to load an existing index + metadata from disk
|
self._load_faiss_index()
|
||||||
with self._storage_lock:
|
|
||||||
self._load_faiss_index()
|
|
||||||
|
|
||||||
def _get_index(self):
|
async def initialize(self):
|
||||||
|
"""Initialize storage data"""
|
||||||
|
# Get the update flag for cross-process update notification
|
||||||
|
self.storage_updated = await get_update_flag(self.namespace)
|
||||||
|
# Get the storage lock for use in other methods
|
||||||
|
self._storage_lock = get_storage_lock()
|
||||||
|
|
||||||
|
async def _get_index(self):
|
||||||
"""Check if the shtorage should be reloaded"""
|
"""Check if the shtorage should be reloaded"""
|
||||||
|
# Acquire lock to prevent concurrent read and write
|
||||||
|
with self._storage_lock:
|
||||||
|
# Check if storage was updated by another process
|
||||||
|
if (is_multiprocess and self.storage_updated.value) or \
|
||||||
|
(not is_multiprocess and self.storage_updated):
|
||||||
|
logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process")
|
||||||
|
# Reload data
|
||||||
|
self._index = faiss.IndexFlatIP(self._dim)
|
||||||
|
self._id_to_meta = {}
|
||||||
|
self._load_faiss_index()
|
||||||
|
if is_multiprocess:
|
||||||
|
self.storage_updated.value = False
|
||||||
|
else:
|
||||||
|
self.storage_updated = False
|
||||||
return self._index
|
return self._index
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
|
||||||
with self._storage_lock:
|
|
||||||
self._save_faiss_index()
|
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert or update vectors in the Faiss index.
|
Insert or update vectors in the Faiss index.
|
||||||
@@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
self._remove_faiss_ids(existing_ids_to_remove)
|
self._remove_faiss_ids(existing_ids_to_remove)
|
||||||
|
|
||||||
# Step 2: Add new vectors
|
# Step 2: Add new vectors
|
||||||
index = self._get_index()
|
index = await self._get_index()
|
||||||
start_idx = index.ntotal
|
start_idx = index.ntotal
|
||||||
index.add(embeddings)
|
index.add(embeddings)
|
||||||
|
|
||||||
@@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Perform the similarity search
|
# Perform the similarity search
|
||||||
distances, indices = self._get_index().search(embedding, top_k)
|
index = await self._get_index()
|
||||||
|
distances, indices = index().search(embedding, top_k)
|
||||||
|
|
||||||
distances = distances[0]
|
distances = distances[0]
|
||||||
indices = indices[0]
|
indices = indices[0]
|
||||||
@@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
logger.warning("Starting with an empty Faiss index.")
|
logger.warning("Starting with an empty Faiss index.")
|
||||||
self._index = faiss.IndexFlatIP(self._dim)
|
self._index = faiss.IndexFlatIP(self._dim)
|
||||||
self._id_to_meta = {}
|
self._id_to_meta = {}
|
||||||
|
|
||||||
|
async def index_done_callback(self) -> None:
|
||||||
|
# Check if storage was updated by another process
|
||||||
|
if is_multiprocess and self.storage_updated.value:
|
||||||
|
# Storage was updated by another process, reload data instead of saving
|
||||||
|
logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...")
|
||||||
|
with self._storage_lock:
|
||||||
|
self._index = faiss.IndexFlatIP(self._dim)
|
||||||
|
self._id_to_meta = {}
|
||||||
|
self._load_faiss_index()
|
||||||
|
self.storage_updated.value = False
|
||||||
|
return False # Return error
|
||||||
|
|
||||||
|
# Acquire lock and perform persistence
|
||||||
|
async with self._storage_lock:
|
||||||
|
try:
|
||||||
|
# Save data to disk
|
||||||
|
self._save_faiss_index()
|
||||||
|
# Set all update flags to False
|
||||||
|
await set_all_update_flags(self.namespace)
|
||||||
|
# Reset own update flag to avoid self-reloading
|
||||||
|
if is_multiprocess:
|
||||||
|
self.storage_updated.value = False
|
||||||
|
else:
|
||||||
|
self.storage_updated = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
|
||||||
|
return False # Return error
|
||||||
|
|
||||||
|
return True # Return success
|
||||||
|
Reference in New Issue
Block a user