feat(storage): Add shared memory support for FAISS

This commit is contained in:
yangdx
2025-02-25 11:25:06 +08:00
parent 362321204f
commit e22e014f22

View File

@@ -2,6 +2,8 @@ import os
import time import time
import asyncio import asyncio
from typing import Any, final from typing import Any, final
import threading
from multiprocessing import Manager
import json import json
import numpy as np import numpy as np
@@ -22,6 +24,27 @@ if not pm.is_installed("faiss"):
import faiss import faiss
# Global variables for shared memory management
# Lock guarding one-time creation of the manager and of per-namespace entries.
_init_lock = threading.Lock()
# multiprocessing Manager instance; stays None until _get_manager() creates it.
_manager = None
# Manager-backed dict keyed by namespace, holding that namespace's Faiss index.
_shared_indices = None
# Manager-backed dict keyed by namespace, holding that namespace's
# {int faiss_id: metadata} mapping.
_shared_meta = None
def _get_manager():
    """Get or create the global manager instance.

    On first call, starts a ``multiprocessing.Manager`` and creates the two
    manager-backed dicts (``_shared_indices`` and ``_shared_meta``). Later
    calls return the already-created manager.

    The module globals are only assigned once *all* three objects have been
    created, so a failure part-way through never leaves ``_manager`` set while
    the dicts are still ``None``; the next call simply retries from scratch.

    Returns:
        The global manager instance.

    Raises:
        RuntimeError: If the manager or its shared dicts cannot be created.
            The original exception is attached as ``__cause__``.
    """
    global _manager, _shared_indices, _shared_meta
    with _init_lock:
        if _manager is None:
            manager = None
            try:
                manager = Manager()
                indices = manager.dict()
                meta = manager.dict()
            except Exception as e:
                logger.error(f"Failed to initialize shared memory manager: {e}")
                if manager is not None:
                    # The manager started but a dict could not be created:
                    # stop its server process instead of leaking it.
                    try:
                        manager.shutdown()
                    except Exception as shutdown_err:
                        logger.warning(
                            f"Failed to shut down partially initialized manager: {shutdown_err}"
                        )
                raise RuntimeError(f"Shared memory initialization failed: {e}") from e
            # Publish atomically (under the lock) only after full success.
            _manager, _shared_indices, _shared_meta = manager, indices, meta
    return _manager
@final @final
@dataclass @dataclass
@@ -51,17 +74,47 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Embedding dimension (e.g. 768) must match your embedding function # Embedding dimension (e.g. 768) must match your embedding function
self._dim = self.embedding_func.embedding_dim self._dim = self.embedding_func.embedding_dim
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). # Ensure manager is initialized
# If you have a large number of vectors, you might want IVF or other indexes. _get_manager()
# For demonstration, we use a simple IndexFlatIP.
self._index = faiss.IndexFlatIP(self._dim)
# Keep a local store for metadata, IDs, etc. # Get or create namespace index and metadata
# Maps <int faiss_id> → metadata (including your original ID). if self.namespace not in _shared_indices:
self._id_to_meta = {} with _init_lock:
if self.namespace not in _shared_indices:
try:
# Create an empty Faiss index for inner product
index = faiss.IndexFlatIP(self._dim)
meta = {}
# Attempt to load an existing index + metadata from disk # Load existing index if available
self._load_faiss_index() if os.path.exists(self._faiss_index_file):
try:
index = faiss.read_index(self._faiss_index_file)
with open(self._meta_file, "r", encoding="utf-8") as f:
stored_dict = json.load(f)
# Convert string keys back to int
meta = {int(k): v for k, v in stored_dict.items()}
logger.info(
f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
logger.error(f"Failed to load Faiss index or metadata: {e}")
logger.warning("Starting with an empty Faiss index.")
index = faiss.IndexFlatIP(self._dim)
meta = {}
_shared_indices[self.namespace] = index
_shared_meta[self.namespace] = meta
except Exception as e:
logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
raise RuntimeError(f"Faiss index initialization failed: {e}")
try:
self._index = _shared_indices[self.namespace]
self._id_to_meta = _shared_meta[self.namespace]
except Exception as e:
logger.error(f"Failed to access shared memory: {e}")
raise RuntimeError(f"Cannot access shared memory: {e}")
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """