Refactor storage implementations to support both single and multi-process modes

• Add a shared storage management module
• Select a process lock or a thread lock according to the run mode (a usage sketch follows the commit metadata below)
Author: yangdx
Date:   2025-02-26 05:38:38 +08:00
Parent: 8050b0f91b
Commit: 2752a764ae
10 changed files with 608 additions and 623 deletions
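The diff below consumes four helpers from the new shared_storage module (get_namespace_data, get_namespace_object, get_storage_lock, is_multiprocess). That module is one of the ten changed files but is not shown here, so the following is only a minimal sketch of the pattern, with internals inferred from how the Faiss storage uses it — treat every detail below as an assumption, not the actual module:

    import threading
    import multiprocessing

    # Hypothetical stand-in for the shared_storage module; only the observable
    # behavior (lock choice, namespace dicts, ".value" holders) is taken from
    # the diff below.
    is_multiprocess = False  # decided once at server startup
    _manager = None
    _lock = None
    _namespaces: dict = {}


    def initialize(multiprocess: bool) -> None:
        """Choose process- or thread-scoped primitives once, at startup."""
        global is_multiprocess, _manager
        is_multiprocess = multiprocess
        if multiprocess:
            _manager = multiprocessing.Manager()


    def get_storage_lock():
        """One shared lock: a manager (process) lock in multi-process mode,
        a plain threading.Lock otherwise."""
        global _lock
        if _lock is None:
            _lock = _manager.Lock() if is_multiprocess else threading.Lock()
        return _lock


    def get_namespace_data(name: str):
        """Shared dict for a namespace: manager-backed across processes,
        a plain dict in single-process mode."""
        if name not in _namespaces:
            _namespaces[name] = _manager.dict() if is_multiprocess else {}
        return _namespaces[name]


    def get_namespace_object(name: str):
        """Single-object holder: exposes .value in multi-process mode; starts
        as None in single-process mode (callers assign the object directly)."""
        if name not in _namespaces:
            _namespaces[name] = _manager.Value("O", None) if is_multiprocess else None
        return _namespaces[name]

Usage in the storage class then reduces to one get_storage_lock() call plus namespace lookups, as the __post_init__ changes below show.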


@@ -2,48 +2,21 @@ import os
 import time
 import asyncio
 from typing import Any, final
-import threading
 import json
 import numpy as np
 from dataclasses import dataclass
 import pipmaster as pm
-from lightrag.api.utils_api import manager as main_process_manager
-from lightrag.utils import (
-    logger,
-    compute_mdhash_id,
-)
-from lightrag.base import (
-    BaseVectorStorage,
-)
+from lightrag.utils import logger, compute_mdhash_id
+from lightrag.base import BaseVectorStorage
+from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
 
 if not pm.is_installed("faiss"):
     pm.install("faiss")
 import faiss  # type: ignore
 
-# Global variables for shared memory management
-_init_lock = threading.Lock()
-_manager = None
-_shared_indices = None
-_shared_meta = None
-
-
-def _get_manager():
-    """Get or create the global manager instance"""
-    global _manager, _shared_indices, _shared_meta
-    with _init_lock:
-        if _manager is None:
-            try:
-                _manager = main_process_manager
-                _shared_indices = _manager.dict()
-                _shared_meta = _manager.dict()
-            except Exception as e:
-                logger.error(f"Failed to initialize shared memory manager: {e}")
-                raise RuntimeError(f"Shared memory initialization failed: {e}")
-    return _manager
-
 
 @final
 @dataclass
@@ -72,48 +45,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
+        self._storage_lock = get_storage_lock()
 
-        # Ensure manager is initialized
-        _get_manager()
+        self._index = get_namespace_object('faiss_indices')
+        self._id_to_meta = get_namespace_data('faiss_meta')
 
-        # Get or create namespace index and metadata
-        if self.namespace not in _shared_indices:
-            with _init_lock:
-                if self.namespace not in _shared_indices:
-                    try:
-                        # Create an empty Faiss index for inner product
-                        index = faiss.IndexFlatIP(self._dim)
-                        meta = {}
-
-                        # Load existing index if available
-                        if os.path.exists(self._faiss_index_file):
-                            try:
-                                index = faiss.read_index(self._faiss_index_file)
-                                with open(self._meta_file, "r", encoding="utf-8") as f:
-                                    stored_dict = json.load(f)
-                                # Convert string keys back to int
-                                meta = {int(k): v for k, v in stored_dict.items()}
-                                logger.info(
-                                    f"Faiss index loaded with {index.ntotal} vectors from {self._faiss_index_file}"
-                                )
-                            except Exception as e:
-                                logger.error(f"Failed to load Faiss index or metadata: {e}")
-                                logger.warning("Starting with an empty Faiss index.")
-                                index = faiss.IndexFlatIP(self._dim)
-                                meta = {}
-
-                        _shared_indices[self.namespace] = index
-                        _shared_meta[self.namespace] = meta
-                    except Exception as e:
-                        logger.error(f"Failed to initialize Faiss index for namespace {self.namespace}: {e}")
-                        raise RuntimeError(f"Faiss index initialization failed: {e}")
-
-        try:
-            self._index = _shared_indices[self.namespace]
-            self._id_to_meta = _shared_meta[self.namespace]
-        except Exception as e:
-            logger.error(f"Failed to access shared memory: {e}")
-            raise RuntimeError(f"Cannot access shared memory: {e}")
+        with self._storage_lock:
+            if is_multiprocess:
+                if self._index.value is None:
+                    # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+                    # If you have a large number of vectors, you might want IVF or other indexes.
+                    # For demonstration, we use a simple IndexFlatIP.
+                    self._index.value = faiss.IndexFlatIP(self._dim)
+            else:
+                if self._index is None:
+                    self._index = faiss.IndexFlatIP(self._dim)
+
+            # Keep a local store for metadata, IDs, etc.
+            # Maps <int faiss_id> → metadata (including your original ID).
+            self._id_to_meta.update({})
+
+            # Attempt to load an existing index + metadata from disk
+            self._load_faiss_index()
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
@@ -168,32 +122,36 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)
 
-        # Upsert logic:
-        # 1. Identify which vectors to remove if they exist
-        # 2. Remove them
-        # 3. Add the new vectors
-        existing_ids_to_remove = []
-        for meta, emb in zip(list_data, embeddings):
-            faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
-            if faiss_internal_id is not None:
-                existing_ids_to_remove.append(faiss_internal_id)
+        with self._storage_lock:
+            # Upsert logic:
+            # 1. Identify which vectors to remove if they exist
+            # 2. Remove them
+            # 3. Add the new vectors
+            existing_ids_to_remove = []
+            for meta, emb in zip(list_data, embeddings):
+                faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+                if faiss_internal_id is not None:
+                    existing_ids_to_remove.append(faiss_internal_id)
 
-        if existing_ids_to_remove:
-            self._remove_faiss_ids(existing_ids_to_remove)
+            if existing_ids_to_remove:
+                self._remove_faiss_ids(existing_ids_to_remove)
 
-        # Step 2: Add new vectors
-        start_idx = self._index.ntotal
-        self._index.add(embeddings)
+            # Step 2: Add new vectors
+            start_idx = (self._index.value if is_multiprocess else self._index).ntotal
+            if is_multiprocess:
+                self._index.value.add(embeddings)
+            else:
+                self._index.add(embeddings)
 
-        # Step 3: Store metadata + vector for each new ID
-        for i, meta in enumerate(list_data):
-            fid = start_idx + i
-            # Store the raw vector so we can rebuild if something is removed
-            meta["__vector__"] = embeddings[i].tolist()
-            self._id_to_meta[fid] = meta
+            # Step 3: Store metadata + vector for each new ID
+            for i, meta in enumerate(list_data):
+                fid = start_idx + i
+                # Store the raw vector so we can rebuild if something is removed
+                meta["__vector__"] = embeddings[i].tolist()
+                self._id_to_meta.update({fid: meta})
 
-        logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
-        return [m["__id__"] for m in list_data]
+            logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+            return [m["__id__"] for m in list_data]
 
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """
@@ -209,54 +167,57 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )
 
         # Perform the similarity search
-        distances, indices = self._index.search(embedding, top_k)
+        with self._storage_lock:
+            distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
 
-        distances = distances[0]
-        indices = indices[0]
+            distances = distances[0]
+            indices = indices[0]
 
-        results = []
-        for dist, idx in zip(distances, indices):
-            if idx == -1:
-                # Faiss returns -1 if no neighbor
-                continue
+            results = []
+            for dist, idx in zip(distances, indices):
+                if idx == -1:
+                    # Faiss returns -1 if no neighbor
+                    continue
 
-            # Cosine similarity threshold
-            if dist < self.cosine_better_than_threshold:
-                continue
+                # Cosine similarity threshold
+                if dist < self.cosine_better_than_threshold:
+                    continue
 
-            meta = self._id_to_meta.get(idx, {})
-            results.append(
-                {
-                    **meta,
-                    "id": meta.get("__id__"),
-                    "distance": float(dist),
-                    "created_at": meta.get("__created_at__"),
-                }
-            )
+                meta = self._id_to_meta.get(idx, {})
+                results.append(
+                    {
+                        **meta,
+                        "id": meta.get("__id__"),
+                        "distance": float(dist),
+                        "created_at": meta.get("__created_at__"),
+                    }
+                )
 
-        return results
+            return results
 
     @property
     def client_storage(self):
         # Return whatever structure LightRAG might need for debugging
-        return {"data": list(self._id_to_meta.values())}
+        with self._storage_lock:
+            return {"data": list(self._id_to_meta.values())}
 
     async def delete(self, ids: list[str]):
        """
        Delete vectors for the provided custom IDs.
        """
        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
-        to_remove = []
-        for cid in ids:
-            fid = self._find_faiss_id_by_custom_id(cid)
-            if fid is not None:
-                to_remove.append(fid)
+        with self._storage_lock:
+            to_remove = []
+            for cid in ids:
+                fid = self._find_faiss_id_by_custom_id(cid)
+                if fid is not None:
+                    to_remove.append(fid)
 
-        if to_remove:
-            self._remove_faiss_ids(to_remove)
-        logger.info(
-            f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
-        )
+            if to_remove:
+                self._remove_faiss_ids(to_remove)
+            logger.debug(
+                f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
+            )
 
     async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
@@ -268,18 +229,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Delete relations for a given entity by scanning metadata.
         """
         logger.debug(f"Searching relations for entity {entity_name}")
-        relations = []
-        for fid, meta in self._id_to_meta.items():
-            if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
-                relations.append(fid)
+        with self._storage_lock:
+            relations = []
+            for fid, meta in self._id_to_meta.items():
+                if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                    relations.append(fid)
 
-        logger.debug(f"Found {len(relations)} relations for {entity_name}")
-        if relations:
-            self._remove_faiss_ids(relations)
-            logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
+            logger.debug(f"Found {len(relations)} relations for {entity_name}")
+            if relations:
+                self._remove_faiss_ids(relations)
+                logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
 
     async def index_done_callback(self) -> None:
-        self._save_faiss_index()
+        with self._storage_lock:
+            self._save_faiss_index()
 
     # --------------------------------------------------------------------------------
     # Internal helper methods
@@ -289,10 +252,11 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Return the Faiss internal ID for a given custom ID, or None if not found.
         """
-        for fid, meta in self._id_to_meta.items():
-            if meta.get("__id__") == custom_id:
-                return fid
-        return None
+        with self._storage_lock:
+            for fid, meta in self._id_to_meta.items():
+                if meta.get("__id__") == custom_id:
+                    return fid
+            return None
 
     def _remove_faiss_ids(self, fid_list):
         """
@@ -300,39 +264,45 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Because IndexFlatIP doesn't support 'removals',
         we rebuild the index excluding those vectors.
         """
-        keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+        with self._storage_lock:
+            keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
 
-        # Rebuild the index
-        vectors_to_keep = []
-        new_id_to_meta = {}
-        for new_fid, old_fid in enumerate(keep_fids):
-            vec_meta = self._id_to_meta[old_fid]
-            vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
-            new_id_to_meta[new_fid] = vec_meta
+            # Rebuild the index
+            vectors_to_keep = []
+            new_id_to_meta = {}
+            for new_fid, old_fid in enumerate(keep_fids):
+                vec_meta = self._id_to_meta[old_fid]
+                vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
+                new_id_to_meta[new_fid] = vec_meta
 
-        # Re-init index
-        self._index = faiss.IndexFlatIP(self._dim)
-        if vectors_to_keep:
-            arr = np.array(vectors_to_keep, dtype=np.float32)
-            self._index.add(arr)
+            # Re-init index
+            new_index = faiss.IndexFlatIP(self._dim)
+            if vectors_to_keep:
+                arr = np.array(vectors_to_keep, dtype=np.float32)
+                new_index.add(arr)
+            if is_multiprocess:
+                self._index.value = new_index
+            else:
+                self._index = new_index
 
-        self._id_to_meta = new_id_to_meta
+            self._id_to_meta.update(new_id_to_meta)
 
     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
-        faiss.write_index(self._index, self._faiss_index_file)
+        with self._storage_lock:
+            faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
 
-        # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
-        # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
-        # We'll keep the int -> dict, but JSON requires string keys.
-        serializable_dict = {}
-        for fid, meta in self._id_to_meta.items():
-            serializable_dict[str(fid)] = meta
+            # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
+            # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
+            # We'll keep the int -> dict, but JSON requires string keys.
+            serializable_dict = {}
+            for fid, meta in self._id_to_meta.items():
+                serializable_dict[str(fid)] = meta
 
-        with open(self._meta_file, "w", encoding="utf-8") as f:
-            json.dump(serializable_dict, f)
+            with open(self._meta_file, "w", encoding="utf-8") as f:
+                json.dump(serializable_dict, f)
 
     def _load_faiss_index(self):
         """
@@ -345,22 +315,31 @@ class FaissVectorDBStorage(BaseVectorStorage):
         try:
             # Load the Faiss index
-            self._index = faiss.read_index(self._faiss_index_file)
+            loaded_index = faiss.read_index(self._faiss_index_file)
+            if is_multiprocess:
+                self._index.value = loaded_index
+            else:
+                self._index = loaded_index
+
             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)
 
             # Convert string keys back to int
-            self._id_to_meta = {}
+            self._id_to_meta.update({})
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta
 
             logger.info(
-                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
+                f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
-            self._index = faiss.IndexFlatIP(self._dim)
-            self._id_to_meta = {}
+            new_index = faiss.IndexFlatIP(self._dim)
+            if is_multiprocess:
+                self._index.value = new_index
+            else:
+                self._index = new_index
+            self._id_to_meta.update({})
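After this change, every index access repeats the "self._index.value if is_multiprocess else self._index" branching. A pair of tiny accessors (hypothetical, not part of this commit) would be one way to keep that in a single place:

    def _get_index(self):
        """Return the live Faiss index in either run mode (hypothetical helper)."""
        return self._index.value if is_multiprocess else self._index

    def _set_index(self, new_index) -> None:
        """Install a rebuilt Faiss index in the mode-appropriate slot (hypothetical helper)."""
        if is_multiprocess:
            self._index.value = new_index
        else:
            self._index = new_index

With these, for example, the upsert step "start_idx = (self._index.value if is_multiprocess else self._index).ntotal" would read "start_idx = self._get_index().ntotal".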