diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index d0ef6ed0..f244c288 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -15,7 +15,12 @@ if not pm.is_installed("faiss"): pm.install("faiss") import faiss # type: ignore -from threading import Lock as ThreadLock +from .shared_storage import ( + get_storage_lock, + get_update_flag, + set_all_update_flags, + is_multiprocess, +) @final @@ -45,29 +50,43 @@ class FaissVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] # Embedding dimension (e.g. 768) must match your embedding function self._dim = self.embedding_func.embedding_dim - self._storage_lock = ThreadLock() - + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # If you have a large number of vectors, you might want IVF or other indexes. # For demonstration, we use a simple IndexFlatIP. self._index = faiss.IndexFlatIP(self._dim) - # Keep a local store for metadata, IDs, etc. # Maps → metadata (including your original ID). self._id_to_meta = {} - # Attempt to load an existing index + metadata from disk - with self._storage_lock: - self._load_faiss_index() + self._load_faiss_index() - def _get_index(self): + async def initialize(self): + """Initialize storage data""" + # Get the update flag for cross-process update notification + self.storage_updated = await get_update_flag(self.namespace) + # Get the storage lock for use in other methods + self._storage_lock = get_storage_lock() + + async def _get_index(self): """Check if the shtorage should be reloaded""" + # Acquire lock to prevent concurrent read and write + with self._storage_lock: + # Check if storage was updated by another process + if (is_multiprocess and self.storage_updated.value) or \ + (not is_multiprocess and self.storage_updated): + logger.info(f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process") + # Reload data + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False return self._index - async def index_done_callback(self) -> None: - with self._storage_lock: - self._save_faiss_index() - + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -135,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._remove_faiss_ids(existing_ids_to_remove) # Step 2: Add new vectors - index = self._get_index() + index = await self._get_index() start_idx = index.ntotal index.add(embeddings) @@ -163,7 +182,8 @@ class FaissVectorDBStorage(BaseVectorStorage): ) # Perform the similarity search - distances, indices = self._get_index().search(embedding, top_k) + index = await self._get_index() + distances, indices = index().search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -316,3 +336,33 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.warning("Starting with an empty Faiss index.") self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} + +async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning(f"Storage for FAISS {self.namespace} was updated by another process, reloading...") + with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() + self.storage_updated.value = False + return False # Return error + + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Set all update flags to False + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") + return False # Return error + + return True # Return success