# Module-wide shared state for Faiss indices, guarded by ``_init_lock``.
# The three globals below are published together by ``_get_manager`` so that
# a non-None ``_manager`` always implies usable ``_shared_indices`` and
# ``_shared_meta`` proxies.
_init_lock = threading.Lock()
_manager = None
_shared_indices = None
_shared_meta = None


def _get_manager():
    """Return the global multiprocessing manager, creating it on first use.

    The first successful call starts a ``multiprocessing.Manager`` and creates
    two manager dicts, stored in the module globals ``_shared_indices`` and
    ``_shared_meta``. Later calls return the same manager without
    re-initializing anything.

    Initialization is all-or-nothing: the globals are only assigned once the
    manager and both dicts exist. A failed attempt leaves every global as
    ``None`` so the next call retries from scratch.

    Returns:
        The module-wide manager instance.

    Raises:
        RuntimeError: If the manager or either shared dict cannot be created.
            The original exception is attached as ``__cause__``.
    """
    global _manager, _shared_indices, _shared_meta
    with _init_lock:
        if _manager is not None:
            return _manager

        manager = None
        try:
            # Build everything in locals first; nothing is visible to other
            # callers until all three objects exist.
            manager = Manager()
            indices = manager.dict()
            meta = manager.dict()
        except Exception as e:
            if manager is not None:
                # The manager started but a dict could not be created:
                # stop it so the half-initialized manager is not leaked.
                try:
                    manager.shutdown()
                except Exception as cleanup_error:
                    logger.warning(
                        f"Failed to shut down shared memory manager: {cleanup_error}"
                    )
            logger.error(f"Failed to initialize shared memory manager: {e}")
            raise RuntimeError(f"Shared memory initialization failed: {e}") from e

        # Publish the dicts before the manager so a non-None ``_manager``
        # never coexists with missing shared dicts.
        _shared_indices = indices
        _shared_meta = meta
        _manager = manager
        return _manager
- self._id_to_meta = {} - - # Attempt to load an existing index + metadata from disk - self._load_faiss_index() + + # Ensure manager is initialized + _get_manager() + + # Get or create namespace index and metadata + if self.namespace not in _shared_indices: + with _init_lock: + if self.namespace not in _shared_indices: + try: + # Create an empty Faiss index for inner product + index = faiss.IndexFlatIP(self._dim) + meta = {} + + # Load existing index if available + if os.path.exists(self._faiss_index_file): + try: + index = faiss.read_index(self._faiss_index_file) + with open(self._meta_file, "r", encoding="utf-8") as f: + stored_dict = json.load(f) + # Convert string keys back to int + meta = {int(k): v for k, v in stored_dict.items()} + logger.info( + f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}" + ) + except Exception as e: + logger.error(f"Failed to load Faiss index or metadata: {e}") + logger.warning("Starting with an empty Faiss index.") + index = faiss.IndexFlatIP(self._dim) + meta = {} + + _shared_indices[self.namespace] = index + _shared_meta[self.namespace] = meta + except Exception as e: + logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}") + raise RuntimeError(f"Faiss index initialization failed: {e}") + + try: + self._index = _shared_indices[self.namespace] + self._id_to_meta = _shared_meta[self.namespace] + except Exception as e: + logger.error(f"Failed to access shared memory: {e}") + raise RuntimeError(f"Cannot access shared memory: {e}") async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """