feat(storage): Add shared memory support for FAISS
This commit is contained in:
@@ -2,6 +2,8 @@ import os
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, final
|
from typing import Any, final
|
||||||
|
import threading
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -22,6 +24,27 @@ if not pm.is_installed("faiss"):
|
|||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
|
# Global variables for shared memory management
|
||||||
|
_init_lock = threading.Lock()
|
||||||
|
_manager = None
|
||||||
|
_shared_indices = None
|
||||||
|
_shared_meta = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_manager():
|
||||||
|
"""Get or create the global manager instance"""
|
||||||
|
global _manager, _shared_indices, _shared_meta
|
||||||
|
with _init_lock:
|
||||||
|
if _manager is None:
|
||||||
|
try:
|
||||||
|
_manager = Manager()
|
||||||
|
_shared_indices = _manager.dict()
|
||||||
|
_shared_meta = _manager.dict()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize shared memory manager: {e}")
|
||||||
|
raise RuntimeError(f"Shared memory initialization failed: {e}")
|
||||||
|
return _manager
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -50,18 +73,48 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
# Embedding dimension (e.g. 768) must match your embedding function
|
# Embedding dimension (e.g. 768) must match your embedding function
|
||||||
self._dim = self.embedding_func.embedding_dim
|
self._dim = self.embedding_func.embedding_dim
|
||||||
|
|
||||||
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
# Ensure manager is initialized
|
||||||
# If you have a large number of vectors, you might want IVF or other indexes.
|
_get_manager()
|
||||||
# For demonstration, we use a simple IndexFlatIP.
|
|
||||||
self._index = faiss.IndexFlatIP(self._dim)
|
# Get or create namespace index and metadata
|
||||||
|
if self.namespace not in _shared_indices:
|
||||||
# Keep a local store for metadata, IDs, etc.
|
with _init_lock:
|
||||||
# Maps <int faiss_id> → metadata (including your original ID).
|
if self.namespace not in _shared_indices:
|
||||||
self._id_to_meta = {}
|
try:
|
||||||
|
# Create an empty Faiss index for inner product
|
||||||
# Attempt to load an existing index + metadata from disk
|
index = faiss.IndexFlatIP(self._dim)
|
||||||
self._load_faiss_index()
|
meta = {}
|
||||||
|
|
||||||
|
# Load existing index if available
|
||||||
|
if os.path.exists(self._faiss_index_file):
|
||||||
|
try:
|
||||||
|
index = faiss.read_index(self._faiss_index_file)
|
||||||
|
with open(self._meta_file, "r", encoding="utf-8") as f:
|
||||||
|
stored_dict = json.load(f)
|
||||||
|
# Convert string keys back to int
|
||||||
|
meta = {int(k): v for k, v in stored_dict.items()}
|
||||||
|
logger.info(
|
||||||
|
f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load Faiss index or metadata: {e}")
|
||||||
|
logger.warning("Starting with an empty Faiss index.")
|
||||||
|
index = faiss.IndexFlatIP(self._dim)
|
||||||
|
meta = {}
|
||||||
|
|
||||||
|
_shared_indices[self.namespace] = index
|
||||||
|
_shared_meta[self.namespace] = meta
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
|
||||||
|
raise RuntimeError(f"Faiss index initialization failed: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._index = _shared_indices[self.namespace]
|
||||||
|
self._id_to_meta = _shared_meta[self.namespace]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to access shared memory: {e}")
|
||||||
|
raise RuntimeError(f"Cannot access shared memory: {e}")
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user