feat(storage): Add shared memory support for FAISS
This commit is contained in:
@@ -2,6 +2,8 @@ import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any, final
|
||||
import threading
|
||||
from multiprocessing import Manager
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
@@ -22,6 +24,27 @@ if not pm.is_installed("faiss"):
|
||||
|
||||
import faiss
|
||||
|
||||
# Global variables for shared memory management
|
||||
_init_lock = threading.Lock()
|
||||
_manager = None
|
||||
_shared_indices = None
|
||||
_shared_meta = None
|
||||
|
||||
|
||||
def _get_manager():
|
||||
"""Get or create the global manager instance"""
|
||||
global _manager, _shared_indices, _shared_meta
|
||||
with _init_lock:
|
||||
if _manager is None:
|
||||
try:
|
||||
_manager = Manager()
|
||||
_shared_indices = _manager.dict()
|
||||
_shared_meta = _manager.dict()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||
return _manager
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
@@ -50,18 +73,48 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Embedding dimension (e.g. 768) must match your embedding function
|
||||
self._dim = self.embedding_func.embedding_dim
|
||||
|
||||
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
||||
# For demonstration, we use a simple IndexFlatIP.
|
||||
self._index = faiss.IndexFlatIP(self._dim)
|
||||
|
||||
# Keep a local store for metadata, IDs, etc.
|
||||
# Maps <int faiss_id> → metadata (including your original ID).
|
||||
self._id_to_meta = {}
|
||||
|
||||
# Attempt to load an existing index + metadata from disk
|
||||
self._load_faiss_index()
|
||||
|
||||
# Ensure manager is initialized
|
||||
_get_manager()
|
||||
|
||||
# Get or create namespace index and metadata
|
||||
if self.namespace not in _shared_indices:
|
||||
with _init_lock:
|
||||
if self.namespace not in _shared_indices:
|
||||
try:
|
||||
# Create an empty Faiss index for inner product
|
||||
index = faiss.IndexFlatIP(self._dim)
|
||||
meta = {}
|
||||
|
||||
# Load existing index if available
|
||||
if os.path.exists(self._faiss_index_file):
|
||||
try:
|
||||
index = faiss.read_index(self._faiss_index_file)
|
||||
with open(self._meta_file, "r", encoding="utf-8") as f:
|
||||
stored_dict = json.load(f)
|
||||
# Convert string keys back to int
|
||||
meta = {int(k): v for k, v in stored_dict.items()}
|
||||
logger.info(
|
||||
f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Faiss index or metadata: {e}")
|
||||
logger.warning("Starting with an empty Faiss index.")
|
||||
index = faiss.IndexFlatIP(self._dim)
|
||||
meta = {}
|
||||
|
||||
_shared_indices[self.namespace] = index
|
||||
_shared_meta[self.namespace] = meta
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
|
||||
raise RuntimeError(f"Faiss index initialization failed: {e}")
|
||||
|
||||
try:
|
||||
self._index = _shared_indices[self.namespace]
|
||||
self._id_to_meta = _shared_meta[self.namespace]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access shared memory: {e}")
|
||||
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user