From e8d0d065f3bcb2b0ebce3c5d9c87cdcd701eab8e Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 23:35:09 +0800 Subject: [PATCH] fix: Improve async handling and FAISS storage reliability - Add async context manager support - Fix embedding data type conversion - Improve error handling in FAISS ops - Add multiprocess storage sync --- lightrag/api/README.md | 2 +- lightrag/kg/faiss_impl.py | 74 +++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 86f18271..35062cad 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -186,7 +186,7 @@ LightRAG supports binding to various LLM/Embedding backends: * openai & openai compatible * azure_openai -Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type. +Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select embedding backend type. 
### Storage Types Supported diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index e71f77a8..940ba73d 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -71,7 +71,7 @@ class FaissVectorDBStorage(BaseVectorStorage): async def _get_index(self): """Check if the shtorage should be reloaded""" # Acquire lock to prevent concurrent read and write - with self._storage_lock: + async with self._storage_lock: # Check if storage was updated by another process if (is_multiprocess and self.storage_updated.value) or ( not is_multiprocess and self.storage_updated @@ -139,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage): ) return [] - # Normalize embeddings for cosine similarity (in-place) + # Convert to float32 and normalize embeddings for cosine similarity (in-place) + embeddings = embeddings.astype(np.float32) faiss.normalize_L2(embeddings) # Upsert logic: @@ -153,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage): existing_ids_to_remove.append(faiss_internal_id) if existing_ids_to_remove: - self._remove_faiss_ids(existing_ids_to_remove) + await self._remove_faiss_ids(existing_ids_to_remove) # Step 2: Add new vectors index = await self._get_index() @@ -185,7 +186,7 @@ class FaissVectorDBStorage(BaseVectorStorage): # Perform the similarity search index = await self._get_index() - distances, indices = index().search(embedding, top_k) + distances, indices = index.search(embedding, top_k) distances = distances[0] indices = indices[0] @@ -229,7 +230,7 @@ class FaissVectorDBStorage(BaseVectorStorage): to_remove.append(fid) if to_remove: - self._remove_faiss_ids(to_remove) + await self._remove_faiss_ids(to_remove) logger.debug( f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" ) @@ -251,7 +252,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Found {len(relations)} relations for {entity_name}") if relations: - self._remove_faiss_ids(relations) + await self._remove_faiss_ids(relations) 
logger.debug(f"Deleted {len(relations)} relations for {entity_name}") # -------------------------------------------------------------------------------- @@ -267,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage): return fid return None - def _remove_faiss_ids(self, fid_list): + async def _remove_faiss_ids(self, fid_list): """ Remove a list of internal Faiss IDs from the index. Because IndexFlatIP doesn't support 'removals', @@ -283,7 +284,7 @@ class FaissVectorDBStorage(BaseVectorStorage): vectors_to_keep.append(vec_meta["__vector__"]) # stored as list new_id_to_meta[new_fid] = vec_meta - with self._storage_lock: + async with self._storage_lock: # Re-init index self._index = faiss.IndexFlatIP(self._dim) if vectors_to_keep: @@ -339,35 +340,34 @@ class FaissVectorDBStorage(BaseVectorStorage): self._index = faiss.IndexFlatIP(self._dim) self._id_to_meta = {} - -async def index_done_callback(self) -> None: - # Check if storage was updated by another process - if is_multiprocess and self.storage_updated.value: - # Storage was updated by another process, reload data instead of saving - logger.warning( - f"Storage for FAISS {self.namespace} was updated by another process, reloading..." 
- ) - with self._storage_lock: - self._index = faiss.IndexFlatIP(self._dim) - self._id_to_meta = {} - self._load_faiss_index() - self.storage_updated.value = False - return False # Return error - - # Acquire lock and perform persistence - async with self._storage_lock: - try: - # Save data to disk - self._save_faiss_index() - # Notify other processes that data has been updated - await set_all_update_flags(self.namespace) - # Reset own update flag to avoid self-reloading - if is_multiprocess: + async def index_done_callback(self) -> None: + # Check if storage was updated by another process + if is_multiprocess and self.storage_updated.value: + # Storage was updated by another process, reload data instead of saving + logger.warning( + f"Storage for FAISS {self.namespace} was updated by another process, reloading..." + ) + async with self._storage_lock: + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} + self._load_faiss_index() self.storage_updated.value = False - else: - self.storage_updated = False - except Exception as e: - logger.error(f"Error saving FAISS index for {self.namespace}: {e}") return False # Return error - return True # Return success + # Acquire lock and perform persistence + async with self._storage_lock: + try: + # Save data to disk + self._save_faiss_index() + # Notify other processes that data has been updated + await set_all_update_flags(self.namespace) + # Reset own update flag to avoid self-reloading + if is_multiprocess: + self.storage_updated.value = False + else: + self.storage_updated = False + except Exception as e: + logger.error(f"Error saving FAISS index for {self.namespace}: {e}") + return False # Return error + + return True # Return success