diff --git a/README.md b/README.md
index ad405e90..dd215b04 100644
--- a/README.md
+++ b/README.md
@@ -465,7 +465,36 @@ For production level scenarios you will most likely want to leverage an enterpri
 >
 > You can compile the AGE from source code and fix it.
+### Using Faiss for Storage
+- Install the required dependencies:
+```
+pip install faiss-cpu
+```
+You can also install `faiss-gpu` if you have GPU support.
+- Here we use `sentence-transformers`, but you can also use an `OpenAIEmbedding` model with `3072` dimensions; if you do, set `embedding_dim=3072` in the `EmbeddingFunc` below.
+
+```
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from lightrag import LightRAG
+from lightrag.utils import EmbeddingFunc
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    embeddings = model.encode(texts, convert_to_numpy=True)
+    return embeddings
+
+# Initialize LightRAG with the LLM model function and embedding function
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=384,
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+    vector_storage="FaissVectorDBStorage",
+    vector_db_storage_cls_kwargs={
+        "cosine_better_than_threshold": 0.3  # Your desired threshold
+    },
+)
+```
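+
+Once initialized, you can insert documents and query as usual. A minimal sketch (assuming `QueryParam` is also imported from `lightrag`):
+
+```
+rag.insert("Your text here")
+print(rag.query("What are the main themes?", param=QueryParam(mode="hybrid")))
+```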
book1 = open("./book_1.txt", encoding="utf-8") + book2 = open("./book_2.txt", encoding="utf-8") + + rag.insert([book1.read(), book2.read()]) + + query_text = "What are the main themes?" + + print("Result (Naive):") + print(rag.query(query_text, param=QueryParam(mode="naive"))) + + print("\nResult (Local):") + print(rag.query(query_text, param=QueryParam(mode="local"))) + + print("\nResult (Global):") + print(rag.query(query_text, param=QueryParam(mode="global"))) + + print("\nResult (Hybrid):") + print(rag.query(query_text, param=QueryParam(mode="hybrid"))) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py new file mode 100644 index 00000000..1688e507 --- /dev/null +++ b/lightrag/kg/faiss_impl.py @@ -0,0 +1,318 @@ +import os +import time +import asyncio +import faiss +import json +import numpy as np +from tqdm.asyncio import tqdm as tqdm_async +from dataclasses import dataclass + +from lightrag.utils import ( + logger, + compute_mdhash_id, +) +from lightrag.base import ( + BaseVectorStorage, +) + + +@dataclass +class FaissVectorDBStorage(BaseVectorStorage): + """ + A Faiss-based Vector DB Storage for LightRAG. + Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. + """ + cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + + def __post_init__(self): + # Grab config values if available + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + self.cosine_better_than_threshold = config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + # Where to save index file if you want persistent storage + self._faiss_index_file = os.path.join( + self.global_config["working_dir"], f"faiss_index_{self.namespace}.index" + ) + self._meta_file = self._faiss_index_file + ".meta.json" + + self._max_batch_size = self.global_config["embedding_batch_num"] + # Embedding dimension (e.g. 768) must match your embedding function + self._dim = self.embedding_func.embedding_dim + + # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity). + # If you have a large number of vectors, you might want IVF or other indexes. + # For demonstration, we use a simple IndexFlatIP. + self._index = faiss.IndexFlatIP(self._dim) + + # Keep a local store for metadata, IDs, etc. + # Maps → metadata (including your original ID). + self._id_to_meta = {} + + # Attempt to load an existing index + metadata from disk + self._load_faiss_index() + + async def upsert(self, data: dict[str, dict]): + """ + Insert or update vectors in the Faiss index. + + data: { + "custom_id_1": { + "content": , + ...metadata... + }, + "custom_id_2": { + "content": , + ...metadata... + }, + ... 
+
+
+def main():
+    WORKING_DIR = "./dickens"
+
+    # Initialize LightRAG with the LLM model function and embedding function
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=384,
+            max_token_size=8192,
+            func=embedding_func,
+        ),
+        vector_storage="FaissVectorDBStorage",
+        vector_db_storage_cls_kwargs={
+            "cosine_better_than_threshold": 0.3  # Your desired threshold
+        },
+    )
+
+    # Insert the documents into LightRAG (close the files after reading)
+    with open("./book_1.txt", encoding="utf-8") as book1, open(
+        "./book_2.txt", encoding="utf-8"
+    ) as book2:
+        rag.insert([book1.read(), book2.read()])
+
+    query_text = "What are the main themes?"
+
+    print("Result (Naive):")
+    print(rag.query(query_text, param=QueryParam(mode="naive")))
+
+    print("\nResult (Local):")
+    print(rag.query(query_text, param=QueryParam(mode="local")))
+
+    print("\nResult (Global):")
+    print(rag.query(query_text, param=QueryParam(mode="global")))
+
+    print("\nResult (Hybrid):")
+    print(rag.query(query_text, param=QueryParam(mode="hybrid")))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
new file mode 100644
index 00000000..1688e507
--- /dev/null
+++ b/lightrag/kg/faiss_impl.py
@@ -0,0 +1,318 @@
+import os
+import time
+import asyncio
+import faiss
+import json
+import numpy as np
+from tqdm.asyncio import tqdm as tqdm_async
+from dataclasses import dataclass
+
+from lightrag.utils import (
+    logger,
+    compute_mdhash_id,
+)
+from lightrag.base import (
+    BaseVectorStorage,
+)
+
+
+@dataclass
+class FaissVectorDBStorage(BaseVectorStorage):
+    """
+    A Faiss-based Vector DB Storage for LightRAG.
+    Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
+    """
+
+    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+
+    def __post_init__(self):
+        # Grab config values if available
+        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        self.cosine_better_than_threshold = config.get(
+            "cosine_better_than_threshold", self.cosine_better_than_threshold
+        )
+
+        # Where to save the index file if you want persistent storage
+        self._faiss_index_file = os.path.join(
+            self.global_config["working_dir"], f"faiss_index_{self.namespace}.index"
+        )
+        self._meta_file = self._faiss_index_file + ".meta.json"
+
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        # Embedding dimension (e.g. 384) must match your embedding function
+        self._dim = self.embedding_func.embedding_dim
+
+        # Create an empty Faiss index for inner product
+        # (inner product on normalized vectors = cosine similarity).
+        # If you have a large number of vectors, you might want IVF or other indexes.
+        # For demonstration, we use a simple IndexFlatIP.
+        self._index = faiss.IndexFlatIP(self._dim)
+
+        # Keep a local store for metadata, IDs, etc.
+        # Maps internal Faiss ID -> metadata (including your original ID).
+        self._id_to_meta = {}
+
+        # Attempt to load an existing index + metadata from disk
+        self._load_faiss_index()
+
+    async def upsert(self, data: dict[str, dict]):
+        """
+        Insert or update vectors in the Faiss index.
+
+        data: {
+            "custom_id_1": {
+                "content": <text>,
+                ...metadata...
+            },
+            "custom_id_2": {
+                "content": <text>,
+                ...metadata...
+            },
+            ...
+        }
+        """
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not data:
+            logger.warning("You are inserting empty data to the vector DB")
+            return []
+
+        current_time = time.time()
+
+        # Prepare data for embedding
+        list_data = []
+        contents = []
+        for k, v in data.items():
+            # Store only known meta fields if needed
+            meta = {mf: v[mf] for mf in self.meta_fields if mf in v}
+            meta["__id__"] = k
+            meta["__created_at__"] = current_time
+            list_data.append(meta)
+            contents.append(v["content"])
+
+        # Split into batches for embedding if needed
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        pbar = tqdm_async(
+            total=len(batches), desc="Generating embeddings", unit="batch"
+        )
+
+        async def wrapped_task(batch):
+            result = await self.embedding_func(batch)
+            pbar.update(1)
+            return result
+
+        embedding_tasks = [wrapped_task(batch) for batch in batches]
+        embeddings_list = await asyncio.gather(*embedding_tasks)
+        pbar.close()
+
+        # Flatten the list of arrays
+        embeddings = np.concatenate(embeddings_list, axis=0)
+        if len(embeddings) != len(list_data):
+            logger.error(
+                f"Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
+            )
+            return []
+
+        # Normalize embeddings for cosine similarity (in-place);
+        # faiss.normalize_L2 requires a contiguous float32 array
+        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
+        faiss.normalize_L2(embeddings)
+
+        # Upsert logic:
+        # 1. Identify which vectors to remove if they already exist
+        # 2. Remove them
+        # 3. Add the new vectors
+        existing_ids_to_remove = []
+        for meta in list_data:
+            faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+            if faiss_internal_id is not None:
+                existing_ids_to_remove.append(faiss_internal_id)
+
+        if existing_ids_to_remove:
+            self._remove_faiss_ids(existing_ids_to_remove)
+
+        # Add the new vectors
+        start_idx = self._index.ntotal
+        self._index.add(embeddings)
+
+        # Store metadata + vector for each new ID
+        for i, meta in enumerate(list_data):
+            fid = start_idx + i
+            # Store the raw vector so we can rebuild the index if something is removed
+            meta["__vector__"] = embeddings[i].tolist()
+            self._id_to_meta[fid] = meta
+
+        logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+        return [m["__id__"] for m in list_data]
+
+    async def query(self, query: str, top_k=5):
+        """
+        Search by a textual query; returns top_k results with their metadata + similarity score.
+        """
+        embedding = await self.embedding_func([query])
+        # embedding is shape (1, dim)
+        embedding = np.array(embedding, dtype=np.float32)
+        faiss.normalize_L2(embedding)  # in-place normalization
+
+        logger.info(
+            f"Query: {query}, top_k: {top_k}, threshold: {self.cosine_better_than_threshold}"
+        )
+
+        # Perform the similarity search
+        distances, indices = self._index.search(embedding, top_k)
+
+        distances = distances[0]
+        indices = indices[0]
+
+        results = []
+        for dist, idx in zip(distances, indices):
+            if idx == -1:
+                # Faiss returns -1 if there is no neighbor
+                continue
+
+            # Apply the cosine similarity threshold
+            if dist < self.cosine_better_than_threshold:
+                continue
+
+            meta = self._id_to_meta.get(idx, {})
+            results.append(
+                {
+                    **meta,
+                    "id": meta.get("__id__"),
+                    "distance": float(dist),
+                    "created_at": meta.get("__created_at__"),
+                }
+            )
+
+        return results
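+
+    # Note on scoring: IndexFlatIP returns inner products, and because both the
+    # stored vectors and the query are L2-normalized, each returned "distance"
+    # is actually a cosine similarity in [-1, 1]. Higher is better, which is why
+    # query() above drops results *below* cosine_better_than_threshold.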
+ """ + logger.info(f"Deleting {len(ids)} vectors from {self.namespace}") + to_remove = [] + for cid in ids: + fid = self._find_faiss_id_by_custom_id(cid) + if fid is not None: + to_remove.append(fid) + + if to_remove: + self._remove_faiss_ids(to_remove) + logger.info(f"Successfully deleted {len(to_remove)} vectors from {self.namespace}") + + async def delete_entity(self, entity_name: str): + """ + Delete a single entity by computing its hashed ID + the same way your code does it with `compute_mdhash_id`. + """ + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") + await self.delete([entity_id]) + + async def delete_entity_relation(self, entity_name: str): + """ + Delete relations for a given entity by scanning metadata. + """ + logger.debug(f"Searching relations for entity {entity_name}") + relations = [] + for fid, meta in self._id_to_meta.items(): + if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: + relations.append(fid) + + logger.debug(f"Found {len(relations)} relations for {entity_name}") + if relations: + self._remove_faiss_ids(relations) + logger.debug(f"Deleted {len(relations)} relations for {entity_name}") + + async def index_done_callback(self): + """ + Called after indexing is done (save Faiss index + metadata). + """ + self._save_faiss_index() + logger.info("Faiss index saved successfully.") + + # -------------------------------------------------------------------------------- + # Internal helper methods + # -------------------------------------------------------------------------------- + + def _find_faiss_id_by_custom_id(self, custom_id: str): + """ + Return the Faiss internal ID for a given custom ID, or None if not found. + """ + for fid, meta in self._id_to_meta.items(): + if meta.get("__id__") == custom_id: + return fid + return None + + def _remove_faiss_ids(self, fid_list): + """ + Remove a list of internal Faiss IDs from the index. + Because IndexFlatIP doesn't support 'removals', + we rebuild the index excluding those vectors. + """ + keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list] + + # Rebuild the index + vectors_to_keep = [] + new_id_to_meta = {} + for new_fid, old_fid in enumerate(keep_fids): + vec_meta = self._id_to_meta[old_fid] + vectors_to_keep.append(vec_meta["__vector__"]) # stored as list + new_id_to_meta[new_fid] = vec_meta + + # Re-init index + self._index = faiss.IndexFlatIP(self._dim) + if vectors_to_keep: + arr = np.array(vectors_to_keep, dtype=np.float32) + self._index.add(arr) + + self._id_to_meta = new_id_to_meta + + def _save_faiss_index(self): + """ + Save the current Faiss index + metadata to disk so it can persist across runs. + """ + faiss.write_index(self._index, self._faiss_index_file) + + # Save metadata dict to JSON. Convert all keys to strings for JSON storage. + # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } } + # We'll keep the int -> dict, but JSON requires string keys. + serializable_dict = {} + for fid, meta in self._id_to_meta.items(): + serializable_dict[str(fid)] = meta + + with open(self._meta_file, "w", encoding="utf-8") as f: + json.dump(serializable_dict, f) + + def _load_faiss_index(self): + """ + Load the Faiss index + metadata from disk if it exists, + and rebuild in-memory structures so we can query. + """ + if not os.path.exists(self._faiss_index_file): + logger.warning("No existing Faiss index file found. 
Starting fresh.") + return + + try: + # Load the Faiss index + self._index = faiss.read_index(self._faiss_index_file) + # Load metadata + with open(self._meta_file, "r", encoding="utf-8") as f: + stored_dict = json.load(f) + + # Convert string keys back to int + self._id_to_meta = {} + for fid_str, meta in stored_dict.items(): + fid = int(fid_str) + self._id_to_meta[fid] = meta + + logger.info( + f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}" + ) + except Exception as e: + logger.error(f"Failed to load Faiss index or metadata: {e}") + logger.warning("Starting with an empty Faiss index.") + self._index = faiss.IndexFlatIP(self._dim) + self._id_to_meta = {} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 92fc954f..22db6994 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -60,6 +60,7 @@ STORAGES = { "PGGraphStorage": ".kg.postgres_impl", "GremlinStorage": ".kg.gremlin_impl", "PGDocStatusStorage": ".kg.postgres_impl", + "FaissVectorDBStorage": ".kg.faiss_impl", }