fix: Improve async handling and FAISS storage reliability

- Add async context manager support
- Fix embedding data type conversion
- Improve error handling in FAISS ops
- Add multiprocess storage sync
This commit is contained in:
yangdx
2025-03-01 23:35:09 +08:00
parent 9aef112d51
commit e8d0d065f3
2 changed files with 38 additions and 38 deletions

View File

@@ -186,7 +186,7 @@ LightRAG supports binding to various LLM/Embedding backends:
* openai & openai compatible
* azure_openai
Use environment variables `LLM_BINDING ` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING ` or CLI argument `--embedding-binding` to select LLM backend type.
Use environment variables `LLM_BINDING` or CLI argument `--llm-binding` to select LLM backend type. Use environment variables `EMBEDDING_BINDING` or CLI argument `--embedding-binding` to select LLM backend type.
### Storage Types Supported

View File

@@ -71,7 +71,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
async def _get_index(self):
"""Check if the shtorage should be reloaded"""
# Acquire lock to prevent concurrent read and write
with self._storage_lock:
async with self._storage_lock:
# Check if storage was updated by another process
if (is_multiprocess and self.storage_updated.value) or (
not is_multiprocess and self.storage_updated
@@ -139,7 +139,8 @@ class FaissVectorDBStorage(BaseVectorStorage):
)
return []
# Normalize embeddings for cosine similarity (in-place)
# Convert to float32 and normalize embeddings for cosine similarity (in-place)
embeddings = embeddings.astype(np.float32)
faiss.normalize_L2(embeddings)
# Upsert logic:
@@ -153,7 +154,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
existing_ids_to_remove.append(faiss_internal_id)
if existing_ids_to_remove:
self._remove_faiss_ids(existing_ids_to_remove)
await self._remove_faiss_ids(existing_ids_to_remove)
# Step 2: Add new vectors
index = await self._get_index()
@@ -185,7 +186,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Perform the similarity search
index = await self._get_index()
distances, indices = index().search(embedding, top_k)
distances, indices = index.search(embedding, top_k)
distances = distances[0]
indices = indices[0]
@@ -229,7 +230,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
to_remove.append(fid)
if to_remove:
self._remove_faiss_ids(to_remove)
await self._remove_faiss_ids(to_remove)
logger.debug(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
@@ -251,7 +252,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.debug(f"Found {len(relations)} relations for {entity_name}")
if relations:
self._remove_faiss_ids(relations)
await self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
# --------------------------------------------------------------------------------
@@ -267,7 +268,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
return fid
return None
def _remove_faiss_ids(self, fid_list):
async def _remove_faiss_ids(self, fid_list):
"""
Remove a list of internal Faiss IDs from the index.
Because IndexFlatIP doesn't support 'removals',
@@ -283,7 +284,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
new_id_to_meta[new_fid] = vec_meta
with self._storage_lock:
async with self._storage_lock:
# Re-init index
self._index = faiss.IndexFlatIP(self._dim)
if vectors_to_keep:
@@ -339,15 +340,14 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
async def index_done_callback(self) -> None:
async def index_done_callback(self) -> None:
# Check if storage was updated by another process
if is_multiprocess and self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
)
with self._storage_lock:
async with self._storage_lock:
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
self._load_faiss_index()