feat(storage): Add shared memory support for FAISS

This commit is contained in:
yangdx
2025-02-25 11:25:06 +08:00
parent 362321204f
commit e22e014f22

View File

@@ -2,6 +2,8 @@ import os
import time
import asyncio
from typing import Any, final
import threading
from multiprocessing import Manager
import json
import numpy as np
@@ -22,6 +24,27 @@ if not pm.is_installed("faiss"):
import faiss
# Module-level state used by _get_manager().
# _init_lock guards the lazy, one-time creation of the manager; the manager
# handle and its two dictionaries remain None until that creation has run.
_init_lock = threading.Lock()
_manager = _shared_indices = _shared_meta = None
def _get_manager():
    """Get or create the global manager instance.

    On the first call this creates a ``multiprocessing`` ``Manager`` together
    with the two manager dictionaries stored in ``_shared_indices`` and
    ``_shared_meta``. Later calls return the already-created manager.

    Returns:
        The module-level manager instance.

    Raises:
        RuntimeError: If the manager or either of its dictionaries cannot be
            created. The original exception is attached as ``__cause__``.
    """
    global _manager, _shared_indices, _shared_meta
    with _init_lock:
        if _manager is None:
            try:
                # Build everything in locals first so that a failure part-way
                # through never leaves _manager set while the dictionaries are
                # still None (which would make later calls skip initialization).
                manager = Manager()
                shared_indices = manager.dict()
                shared_meta = manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                raise RuntimeError(f"Shared memory initialization failed: {e}") from e
            # Publish all three globals together only after full success.
            _manager = manager
            _shared_indices = shared_indices
            _shared_meta = shared_meta
        return _manager
@final
@dataclass
@@ -50,18 +73,48 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
# Embedding dimension (e.g. 768) must match your embedding function
self._dim = self.embedding_func.embedding_dim
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
# If you have a large number of vectors, you might want IVF or other indexes.
# For demonstration, we use a simple IndexFlatIP.
self._index = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc.
# Maps <int faiss_id> → metadata (including your original ID).
self._id_to_meta = {}
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
# Ensure manager is initialized
_get_manager()
# Get or create namespace index and metadata
if self.namespace not in _shared_indices:
with _init_lock:
if self.namespace not in _shared_indices:
try:
# Create an empty Faiss index for inner product
index = faiss.IndexFlatIP(self._dim)
meta = {}
# Load existing index if available
if os.path.exists(self._faiss_index_file):
try:
index = faiss.read_index(self._faiss_index_file)
with open(self._meta_file, "r", encoding="utf-8") as f:
stored_dict = json.load(f)
# Convert string keys back to int
meta = {int(k): v for k, v in stored_dict.items()}
logger.info(
f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
logger.error(f"Failed to load Faiss index or metadata: {e}")
logger.warning("Starting with an empty Faiss index.")
index = faiss.IndexFlatIP(self._dim)
meta = {}
_shared_indices[self.namespace] = index
_shared_meta[self.namespace] = meta
except Exception as e:
logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
raise RuntimeError(f"Faiss index initialization failed: {e}")
try:
self._index = _shared_indices[self.namespace]
self._id_to_meta = _shared_meta[self.namespace]
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""