diff --git a/README.md b/README.md
index ad405e90..dd215b04 100644
--- a/README.md
+++ b/README.md
@@ -465,7 +465,36 @@ For production level scenarios you will most likely want to leverage an enterpri
 >
 > You can compile AGE from source code and fix it.
 
+### Using Faiss for Storage
+- Install the required dependencies:
+```bash
+pip install faiss-cpu
+```
+You can also install `faiss-gpu` if you have GPU support.
+- Here we use `sentence-transformers`, but you can also use an `OpenAIEmbedding` model with `3072` dimensions.
+
+```python
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from lightrag import LightRAG
+from lightrag.utils import EmbeddingFunc
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    embeddings = model.encode(texts, convert_to_numpy=True)
+    return embeddings
+
+# Initialize LightRAG with the LLM model function and embedding function
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=384,  # must match the embedding model's output size
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+    vector_storage="FaissVectorDBStorage",
+    vector_db_storage_cls_kwargs={
+        "cosine_better_than_threshold": 0.3  # your desired similarity threshold
+    },
+)
+```
 
 ### Insert Custom KG
 
diff --git a/examples/test_faiss.py b/examples/test_faiss.py
new file mode 100644
index 00000000..ab0ef9f7
--- /dev/null
+++ b/examples/test_faiss.py
@@ -0,0 +1,99 @@
+import os
+import logging
+import numpy as np
+
+from dotenv import load_dotenv
+from sentence_transformers import SentenceTransformer
+
+from openai import AzureOpenAI
+from lightrag import LightRAG, QueryParam
+from lightrag.utils import EmbeddingFunc
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+
+# Load environment variables from the .env file
+load_dotenv()
+AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
+AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
+AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    # Create a client for Azure OpenAI
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_OPENAI_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT,
+    )
+
+    # Build the messages list for the conversation
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    if history_messages:
+        messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    # Call the LLM
+    chat_completion = client.chat.completions.create(
+        model=AZURE_OPENAI_DEPLOYMENT,
+        messages=messages,
+        temperature=kwargs.get("temperature", 0),
+        top_p=kwargs.get("top_p", 1),
+        n=kwargs.get("n", 1),
+    )
+
+    return chat_completion.choices[0].message.content
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    # NOTE: for repeated calls, consider loading the model once at module level
+    model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = model.encode(texts, convert_to_numpy=True)
+    return embeddings
+
+
+def main():
+    WORKING_DIR = "./dickens"
+
+    # Initialize LightRAG with the LLM model function and embedding function
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=384,
+            max_token_size=8192,
+            func=embedding_func,
+        ),
+        vector_storage="FaissVectorDBStorage",
+        vector_db_storage_cls_kwargs={
+            "cosine_better_than_threshold": 0.3  # your desired similarity threshold
+        },
+    )
+
+    # Insert the full text of both books into LightRAG
+    with open("./book_1.txt", encoding="utf-8") as book1, open(
+        "./book_2.txt", encoding="utf-8"
+    ) as book2:
+        rag.insert([book1.read(), book2.read()])
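+
+    # Query mode summary for the calls below (see the LightRAG README for
+    # details): "naive" = plain vector search over chunks, "local" =
+    # entity-level graph context, "global" = relationship-level context,
+    # "hybrid" = local + global combined.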
open("./book_2.txt", encoding="utf-8") + + rag.insert([book1.read(), book2.read()]) + + query_text = "What are the main themes?" + + print("Result (Naive):") + print(rag.query(query_text, param=QueryParam(mode="naive"))) + + print("\nResult (Local):") + print(rag.query(query_text, param=QueryParam(mode="local"))) + + print("\nResult (Global):") + print(rag.query(query_text, param=QueryParam(mode="global"))) + + print("\nResult (Hybrid):") + print(rag.query(query_text, param=QueryParam(mode="hybrid"))) + + +if __name__ == "__main__": + main() diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index e162f5ec..e1b24731 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1,9 +1,37 @@ -from fastapi import FastAPI, HTTPException, File, UploadFile, Form, Request +from fastapi import ( + FastAPI, + HTTPException, + File, + UploadFile, + Form, + Request, + BackgroundTasks, +) + +# Backend (Python) +# Add this to store progress globally +from typing import Dict +import threading + +# Global progress tracker +scan_progress: Dict = { + "is_scanning": False, + "current_file": "", + "indexed_count": 0, + "total_files": 0, + "progress": 0, +} + +# Lock for thread-safe operations +progress_lock = threading.Lock() + +import json +import os + from fastapi.staticfiles import StaticFiles from pydantic import BaseModel import logging import argparse -import json import time import re from typing import List, Dict, Any, Optional, Union @@ -16,7 +44,6 @@ from pathlib import Path import shutil import aiofiles from ascii_colors import trace_exception, ASCIIColors -import os import sys import configparser @@ -538,7 +565,7 @@ class DocumentManager: # Create input directory if it doesn't exist self.input_dir.mkdir(parents=True, exist_ok=True) - def scan_directory(self) -> List[Path]: + def scan_directory_for_new_files(self) -> List[Path]: """Scan input directory for new files""" new_files = [] for ext in self.supported_extensions: @@ -547,6 +574,14 @@ class DocumentManager: new_files.append(file_path) return new_files + def scan_directory(self) -> List[Path]: + """Scan input directory for new files""" + new_files = [] + for ext in self.supported_extensions: + for file_path in self.input_dir.rglob(f"*{ext}"): + new_files.append(file_path) + return new_files + def mark_as_indexed(self, file_path: Path): """Mark a file as indexed""" self.indexed_files.add(file_path) @@ -730,7 +765,7 @@ def create_app(args): # Startup logic if args.auto_scan_at_startup: try: - new_files = doc_manager.scan_directory() + new_files = doc_manager.scan_directory_for_new_files() for file_path in new_files: try: await index_file(file_path) @@ -983,42 +1018,59 @@ def create_app(args): logging.warning(f"No content extracted from file: {file_path}") @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) - async def scan_for_new_documents(): - """ - Manually trigger scanning for new documents in the directory managed by `doc_manager`. + async def scan_for_new_documents(background_tasks: BackgroundTasks): + """Trigger the scanning process""" + global scan_progress - This endpoint facilitates manual initiation of a document scan to identify and index new files. - It processes all newly detected files, attempts indexing each file, logs any errors that occur, - and returns a summary of the operation. 
+        with progress_lock:
+            if scan_progress["is_scanning"]:
+                return {"status": "already_scanning"}
 
-        Returns:
-            dict: A dictionary containing:
-                - "status" (str): Indicates success or failure of the scanning process.
-                - "indexed_count" (int): The number of successfully indexed documents.
-                - "total_documents" (int): Total number of documents that have been indexed so far.
 
+            scan_progress["is_scanning"] = True
+            scan_progress["indexed_count"] = 0
+            scan_progress["progress"] = 0
+
+        # Start the scanning process in the background
+        background_tasks.add_task(run_scanning_process)
+
+        return {"status": "scanning_started"}
+
+    async def run_scanning_process():
+        """Background task to scan and index documents"""
+        global scan_progress
 
-        Raises:
-            HTTPException: If an error occurs during the document scanning process, a 500 status
-                code is returned with details about the exception.
-        """
         try:
-            new_files = doc_manager.scan_directory()
-            indexed_count = 0
+            new_files = doc_manager.scan_directory_for_new_files()
+            with progress_lock:
+                scan_progress["total_files"] = len(new_files)
 
             for file_path in new_files:
                 try:
+                    with progress_lock:
+                        scan_progress["current_file"] = os.path.basename(file_path)
+
                     await index_file(file_path)
-                    indexed_count += 1
+
+                    with progress_lock:
+                        scan_progress["indexed_count"] += 1
+                        scan_progress["progress"] = (
+                            scan_progress["indexed_count"]
+                            / scan_progress["total_files"]
+                        ) * 100
+
                 except Exception as e:
                     logging.error(f"Error indexing file {file_path}: {str(e)}")
 
-            return {
-                "status": "success",
-                "indexed_count": indexed_count,
-                "total_documents": len(doc_manager.indexed_files),
-            }
         except Exception as e:
-            raise HTTPException(status_code=500, detail=str(e))
+            logging.error(f"Error during scanning process: {str(e)}")
+        finally:
+            with progress_lock:
+                scan_progress["is_scanning"] = False
+
+    @app.get("/documents/scan-progress")
+    async def get_scan_progress():
+        """Get the current scanning progress"""
+        with progress_lock:
+            return scan_progress
 
     @app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
     async def upload_to_input_dir(file: UploadFile = File(...)):
@@ -1849,7 +1901,7 @@ def create_app(args):
             "status": "healthy",
             "working_directory": str(args.working_dir),
             "input_directory": str(args.input_dir),
-            "indexed_files": files,
+            "indexed_files": [str(f) for f in files],
             "indexed_files_count": len(files),
             "configuration": {
                 # LLM configuration binding/host address (if applicable)/model (if applicable)
diff --git a/lightrag/api/static/index.html b/lightrag/api/static/index.html
index 60900c03..c9659d5e 100644
--- a/lightrag/api/static/index.html
+++ b/lightrag/api/static/index.html
@@ -98,7 +98,7 @@
 
-    <script src="js/lightrag_api.js"></script>
+    <script src="js/api.js"></script>
 
diff --git a/lightrag/api/static/js/lightrag_api.js b/lightrag/api/static/js/api.js
similarity index 86%
rename from lightrag/api/static/js/lightrag_api.js
rename to lightrag/api/static/js/api.js
index 3c2ff69c..b610eb10 100644
--- a/lightrag/api/static/js/lightrag_api.js
+++ b/lightrag/api/static/js/api.js
@@ -1,375 +1,408 @@
-// State management
-const state = {
-    apiKey: localStorage.getItem('apiKey') || '',
-    files: [],
-    indexedFiles: [],
-    currentPage: 'file-manager'
-};
-
-// Utility functions
-const showToast = (message, duration = 3000) => {
-    const toast = document.getElementById('toast');
-    toast.querySelector('div').textContent = message;
-    toast.classList.remove('hidden');
-    setTimeout(() => toast.classList.add('hidden'), duration);
-};
-
-const fetchWithAuth = async (url, options = {}) => {
-    const headers = {
-        ...(options.headers || {}),
-        ...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {})
-    };
-    return fetch(url, { ...options, headers });
-};
-
-// Page renderers
-const pages = {
-    'file-manager': () => `
-        <!-- File Manager page (markup stripped during extraction): header,
-             drag-and-drop zone wrapping #fileInput, "Selected Files" list
-             (#fileList), #uploadBtn and #rescanBtn with #uploadProgress bar
-             and #uploadStatus text, and "Indexed Files" list (#indexedFiles) -->
-    `,
-
-    'query': () => `
-        <!-- Query Database page (markup stripped): #queryInput textarea,
-             #queryMode select (naive/local/global/hybrid), #queryBtn button,
-             #queryResult pane -->
-    `,
-
-    'knowledge-graph': () => `
-        <!-- "Under Construction" notice (markup stripped): knowledge graph
-             visualization will be available in a future update -->
-    `,
-
-    'status': () => `
-        <!-- System Status page (markup stripped): "System Health" panel
-             (#healthStatus) and "Configuration" panel (#configStatus) -->
-    `,
-
-    'settings': () => `
-        <!-- Settings page (markup stripped): API key input (#apiKeyInput)
-             and save button (#saveSettings) -->
- ` -}; - -// Page handlers -const handlers = { - 'file-manager': () => { - const fileInput = document.getElementById('fileInput'); - const dropZone = fileInput.parentElement.parentElement; - const fileList = document.querySelector('#fileList div'); - const indexedFiles = document.querySelector('#indexedFiles div'); - const uploadBtn = document.getElementById('uploadBtn'); - - const updateFileList = () => { - fileList.innerHTML = state.files.map(file => ` -
-                <!-- selected-file row (markup stripped): ${file.name} with a remove button -->
- `).join(''); - }; - - const updateIndexedFiles = async () => { - const response = await fetchWithAuth('/health'); - const data = await response.json(); - indexedFiles.innerHTML = data.indexed_files.map(file => ` -
-                <!-- indexed-file row (markup stripped): ${file} -->
- `).join(''); - }; - - dropZone.addEventListener('dragover', (e) => { - e.preventDefault(); - dropZone.classList.add('border-blue-500'); - }); - - dropZone.addEventListener('dragleave', () => { - dropZone.classList.remove('border-blue-500'); - }); - - dropZone.addEventListener('drop', (e) => { - e.preventDefault(); - dropZone.classList.remove('border-blue-500'); - const files = Array.from(e.dataTransfer.files); - state.files.push(...files); - updateFileList(); - }); - - fileInput.addEventListener('change', () => { - state.files.push(...Array.from(fileInput.files)); - updateFileList(); - }); - - uploadBtn.addEventListener('click', async () => { - if (state.files.length === 0) { - showToast('Please select files to upload'); - return; - } - let apiKey = localStorage.getItem('apiKey') || ''; - const progress = document.getElementById('uploadProgress'); - const progressBar = progress.querySelector('div'); - const statusText = document.getElementById('uploadStatus'); - progress.classList.remove('hidden'); - - for (let i = 0; i < state.files.length; i++) { - const formData = new FormData(); - formData.append('file', state.files[i]); - - try { - await fetch('/documents/upload', { - method: 'POST', - headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {}, - body: formData - }); - - const percentage = ((i + 1) / state.files.length) * 100; - progressBar.style.width = `${percentage}%`; - statusText.textContent = `${i + 1}/${state.files.length}`; - } catch (error) { - console.error('Upload error:', error); - } - } - progress.classList.add('hidden'); - }); - rescanBtn.addEventListener('click', async () => { - let apiKey = localStorage.getItem('apiKey') || ''; - const progress = document.getElementById('uploadProgress'); - const progressBar = progress.querySelector('div'); - const statusText = document.getElementById('uploadStatus'); - progress.classList.remove('hidden'); - try { - const scan_output = await fetch('/documents/scan', { - method: 'GET', - }); - statusText.textContent = scan_output.data; - } catch (error) { - console.error('Upload error:', error); - } - progress.classList.add('hidden'); - }); - updateIndexedFiles(); - }, - - 'query': () => { - const queryBtn = document.getElementById('queryBtn'); - const queryInput = document.getElementById('queryInput'); - const queryMode = document.getElementById('queryMode'); - const queryResult = document.getElementById('queryResult'); - - let apiKey = localStorage.getItem('apiKey') || ''; - - queryBtn.addEventListener('click', async () => { - const query = queryInput.value.trim(); - if (!query) { - showToast('Please enter a query'); - return; - } - - queryBtn.disabled = true; - queryBtn.innerHTML = ` - - - - - Processing... - `; - - try { - const response = await fetchWithAuth('/query', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - query, - mode: queryMode.value, - stream: false, - only_need_context: false - }) - }); - - const data = await response.json(); - queryResult.innerHTML = marked.parse(data.response); - } catch (error) { - showToast('Error processing query'); - } finally { - queryBtn.disabled = false; - queryBtn.textContent = 'Send Query'; - } - }); - }, - - 'status': async () => { - const healthStatus = document.getElementById('healthStatus'); - const configStatus = document.getElementById('configStatus'); - - try { - const response = await fetchWithAuth('/health'); - const data = await response.json(); - - healthStatus.innerHTML = ` -
-                <!-- health panel (markup stripped): status badge ${data.status};
-                     Working Directory: ${data.working_directory};
-                     Input Directory: ${data.input_directory};
-                     Indexed Files: ${data.indexed_files_count} -->
-            `;
-
-            configStatus.innerHTML = Object.entries(data.configuration)
-                .map(([key, value]) => `
-                    <!-- config row (markup stripped): ${key}: ${value} -->
- `).join(''); - } catch (error) { - showToast('Error fetching status'); - } - }, - - 'settings': () => { - const saveBtn = document.getElementById('saveSettings'); - const apiKeyInput = document.getElementById('apiKeyInput'); - - saveBtn.addEventListener('click', () => { - state.apiKey = apiKeyInput.value; - localStorage.setItem('apiKey', state.apiKey); - showToast('Settings saved successfully'); - }); - } -}; - -// Navigation handling -document.querySelectorAll('.nav-item').forEach(item => { - item.addEventListener('click', (e) => { - e.preventDefault(); - const page = item.dataset.page; - document.getElementById('content').innerHTML = pages[page](); - if (handlers[page]) handlers[page](); - state.currentPage = page; - }); -}); - -// Initialize with file manager -document.getElementById('content').innerHTML = pages['file-manager'](); -handlers['file-manager'](); - -// Global functions -window.removeFile = (fileName) => { - state.files = state.files.filter(file => file.name !== fileName); - document.querySelector('#fileList div').innerHTML = state.files.map(file => ` -
-            <!-- selected-file row (markup stripped): ${file.name} with a remove button -->
- `).join(''); +// State management +const state = { + apiKey: localStorage.getItem('apiKey') || '', + files: [], + indexedFiles: [], + currentPage: 'file-manager' +}; + +// Utility functions +const showToast = (message, duration = 3000) => { + const toast = document.getElementById('toast'); + toast.querySelector('div').textContent = message; + toast.classList.remove('hidden'); + setTimeout(() => toast.classList.add('hidden'), duration); +}; + +const fetchWithAuth = async (url, options = {}) => { + const headers = { + ...(options.headers || {}), + ...(state.apiKey ? { 'X-API-Key': state.apiKey } : {}) // Use X-API-Key instead of Bearer + }; + return fetch(url, { ...options, headers }); +}; + + +// Page renderers +const pages = { + 'file-manager': () => ` +
+        <!-- File Manager page (markup stripped during extraction): header,
+             drag-and-drop zone wrapping #fileInput, "Selected Files" list
+             (#fileList), #uploadBtn and #rescanBtn with #uploadProgress bar
+             and #uploadStatus text, and "Indexed Files" list (#indexedFiles) -->
+    `,
+
+    'query': () => `
+        <!-- Query Database page (markup stripped): #queryInput textarea,
+             #queryMode select (naive/local/global/hybrid), #queryBtn button,
+             #queryResult pane -->
+    `,
+
+    'knowledge-graph': () => `
+        <!-- "Under Construction" notice (markup stripped): knowledge graph
+             visualization will be available in a future update -->
+    `,
+
+    'status': () => `
+        <!-- System Status page (markup stripped): "System Health" panel
+             (#healthStatus) and "Configuration" panel (#configStatus) -->
+    `,
+
+    'settings': () => `
+        <!-- Settings page (markup stripped): API key input (#apiKeyInput)
+             and save button (#saveSettings) -->
+ ` +}; + +// Page handlers +const handlers = { + 'file-manager': () => { + const fileInput = document.getElementById('fileInput'); + const dropZone = fileInput.parentElement.parentElement; + const fileList = document.querySelector('#fileList div'); + const indexedFiles = document.querySelector('#indexedFiles div'); + const uploadBtn = document.getElementById('uploadBtn'); + + const updateFileList = () => { + fileList.innerHTML = state.files.map(file => ` +
+                <!-- selected-file row (markup stripped): ${file.name} with a remove button -->
+ `).join(''); + }; + + const updateIndexedFiles = async () => { + const response = await fetchWithAuth('/health'); + const data = await response.json(); + indexedFiles.innerHTML = data.indexed_files.map(file => ` +
+                <!-- indexed-file row (markup stripped): ${file} -->
+ `).join(''); + }; + + dropZone.addEventListener('dragover', (e) => { + e.preventDefault(); + dropZone.classList.add('border-blue-500'); + }); + + dropZone.addEventListener('dragleave', () => { + dropZone.classList.remove('border-blue-500'); + }); + + dropZone.addEventListener('drop', (e) => { + e.preventDefault(); + dropZone.classList.remove('border-blue-500'); + const files = Array.from(e.dataTransfer.files); + state.files.push(...files); + updateFileList(); + }); + + fileInput.addEventListener('change', () => { + state.files.push(...Array.from(fileInput.files)); + updateFileList(); + }); + + uploadBtn.addEventListener('click', async () => { + if (state.files.length === 0) { + showToast('Please select files to upload'); + return; + } + let apiKey = localStorage.getItem('apiKey') || ''; + const progress = document.getElementById('uploadProgress'); + const progressBar = progress.querySelector('div'); + const statusText = document.getElementById('uploadStatus'); + progress.classList.remove('hidden'); + + for (let i = 0; i < state.files.length; i++) { + const formData = new FormData(); + formData.append('file', state.files[i]); + + try { + await fetch('/documents/upload', { + method: 'POST', + headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {}, + body: formData + }); + + const percentage = ((i + 1) / state.files.length) * 100; + progressBar.style.width = `${percentage}%`; + statusText.textContent = `${i + 1}/${state.files.length}`; + } catch (error) { + console.error('Upload error:', error); + } + } + progress.classList.add('hidden'); + }); + + rescanBtn.addEventListener('click', async () => { + const progress = document.getElementById('uploadProgress'); + const progressBar = progress.querySelector('div'); + const statusText = document.getElementById('uploadStatus'); + progress.classList.remove('hidden'); + + try { + // Start the scanning process + const scanResponse = await fetch('/documents/scan', { + method: 'POST', + }); + + if (!scanResponse.ok) { + throw new Error('Scan failed to start'); + } + + // Start polling for progress + const pollInterval = setInterval(async () => { + const progressResponse = await fetch('/documents/scan-progress'); + const progressData = await progressResponse.json(); + + // Update progress bar + progressBar.style.width = `${progressData.progress}%`; + + // Update status text + if (progressData.total_files > 0) { + statusText.textContent = `Processing ${progressData.current_file} (${progressData.indexed_count}/${progressData.total_files})`; + } + + // Check if scanning is complete + if (!progressData.is_scanning) { + clearInterval(pollInterval); + progress.classList.add('hidden'); + statusText.textContent = 'Scan complete!'; + } + }, 1000); // Poll every second + + } catch (error) { + console.error('Upload error:', error); + progress.classList.add('hidden'); + statusText.textContent = 'Error during scanning process'; + } + }); + + + updateIndexedFiles(); + }, + + 'query': () => { + const queryBtn = document.getElementById('queryBtn'); + const queryInput = document.getElementById('queryInput'); + const queryMode = document.getElementById('queryMode'); + const queryResult = document.getElementById('queryResult'); + + let apiKey = localStorage.getItem('apiKey') || ''; + + queryBtn.addEventListener('click', async () => { + const query = queryInput.value.trim(); + if (!query) { + showToast('Please enter a query'); + return; + } + + queryBtn.disabled = true; + queryBtn.innerHTML = ` + + + + + Processing... 
+ `; + + try { + const response = await fetchWithAuth('/query', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query, + mode: queryMode.value, + stream: false, + only_need_context: false + }) + }); + + const data = await response.json(); + queryResult.innerHTML = marked.parse(data.response); + } catch (error) { + showToast('Error processing query'); + } finally { + queryBtn.disabled = false; + queryBtn.textContent = 'Send Query'; + } + }); + }, + + 'status': async () => { + const healthStatus = document.getElementById('healthStatus'); + const configStatus = document.getElementById('configStatus'); + + try { + const response = await fetchWithAuth('/health'); + const data = await response.json(); + + healthStatus.innerHTML = ` +
+                <!-- health panel (markup stripped): status badge ${data.status};
+                     Working Directory: ${data.working_directory};
+                     Input Directory: ${data.input_directory};
+                     Indexed Files: ${data.indexed_files_count} -->
+            `;
+
+            configStatus.innerHTML = Object.entries(data.configuration)
+                .map(([key, value]) => `
+                    <!-- config row (markup stripped): ${key}: ${value} -->
+ `).join(''); + } catch (error) { + showToast('Error fetching status'); + } + }, + + 'settings': () => { + const saveBtn = document.getElementById('saveSettings'); + const apiKeyInput = document.getElementById('apiKeyInput'); + + saveBtn.addEventListener('click', () => { + state.apiKey = apiKeyInput.value; + localStorage.setItem('apiKey', state.apiKey); + showToast('Settings saved successfully'); + }); + } +}; + +// Navigation handling +document.querySelectorAll('.nav-item').forEach(item => { + item.addEventListener('click', (e) => { + e.preventDefault(); + const page = item.dataset.page; + document.getElementById('content').innerHTML = pages[page](); + if (handlers[page]) handlers[page](); + state.currentPage = page; + }); +}); + +// Initialize with file manager +document.getElementById('content').innerHTML = pages['file-manager'](); +handlers['file-manager'](); + +// Global functions +window.removeFile = (fileName) => { + state.files = state.files.filter(file => file.name !== fileName); + document.querySelector('#fileList div').innerHTML = state.files.map(file => ` +
+            <!-- selected-file row (markup stripped): ${file.name} with a remove button -->
+    `).join('');
+};
diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py
new file mode 100644
index 00000000..fc6aa779
--- /dev/null
+++ b/lightrag/kg/faiss_impl.py
@@ -0,0 +1,323 @@
+import os
+import time
+import asyncio
+import faiss
+import json
+import numpy as np
+from tqdm.asyncio import tqdm as tqdm_async
+from dataclasses import dataclass
+
+from lightrag.utils import (
+    logger,
+    compute_mdhash_id,
+)
+from lightrag.base import (
+    BaseVectorStorage,
+)
+
+
+@dataclass
+class FaissVectorDBStorage(BaseVectorStorage):
+    """
+    A Faiss-based Vector DB Storage for LightRAG.
+    Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
+    """
+
+    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+
+    def __post_init__(self):
+        # Grab config values if available
+        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        self.cosine_better_than_threshold = config.get(
+            "cosine_better_than_threshold", self.cosine_better_than_threshold
+        )
+
+        # Where to save the index file for persistent storage
+        self._faiss_index_file = os.path.join(
+            self.global_config["working_dir"], f"faiss_index_{self.namespace}.index"
+        )
+        self._meta_file = self._faiss_index_file + ".meta.json"
+
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        # Embedding dimension (e.g. 768) must match your embedding function
+        self._dim = self.embedding_func.embedding_dim
+
+        # Create an empty Faiss index for inner product search (on normalized
+        # vectors this is equivalent to cosine similarity).
+        # If you have a large number of vectors, you might want IVF or other indexes.
+        # For demonstration, we use a simple IndexFlatIP.
+        self._index = faiss.IndexFlatIP(self._dim)
+
+        # Keep a local store for metadata, IDs, etc.
+        # Maps an internal Faiss ID (int) -> metadata dict (including the original ID).
+        self._id_to_meta = {}
+
+        # Attempt to load an existing index + metadata from disk
+        self._load_faiss_index()
+
+    async def upsert(self, data: dict[str, dict]):
+        """
+        Insert or update vectors in the Faiss index.
+
+        data: {
+            "custom_id_1": {
+                "content": "<text to embed>",
+                ...metadata...
+            },
+            "custom_id_2": {
+                "content": "<text to embed>",
+                ...metadata...
+            },
+            ...
+        }
+        """
+        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+        if not data:
+            logger.warning("You are inserting empty data into the vector DB")
+            return []
+
+        current_time = time.time()
+
+        # Prepare data for embedding
+        list_data = []
+        contents = []
+        for k, v in data.items():
+            # Store only known meta fields if needed
+            meta = {mf: v[mf] for mf in self.meta_fields if mf in v}
+            meta["__id__"] = k
+            meta["__created_at__"] = current_time
+            list_data.append(meta)
+            contents.append(v["content"])
+
+        # Split into batches for embedding if needed
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+
+        pbar = tqdm_async(
+            total=len(batches), desc="Generating embeddings", unit="batch"
+        )
+
+        async def wrapped_task(batch):
+            result = await self.embedding_func(batch)
+            pbar.update(1)
+            return result
+
+        embedding_tasks = [wrapped_task(batch) for batch in batches]
+        embeddings_list = await asyncio.gather(*embedding_tasks)
+
+        # Flatten the list of arrays
+        embeddings = np.concatenate(embeddings_list, axis=0)
+        if len(embeddings) != len(list_data):
+            logger.error(
+                f"Embedding size mismatch. "
+                f"Embeddings: {len(embeddings)}, Data: {len(list_data)}"
+            )
+            return []
+
+        # Normalize embeddings for cosine similarity (in-place)
+        faiss.normalize_L2(embeddings)
+
+        # Upsert logic:
+        # 1. Find and remove any existing Faiss entries for these IDs
+        # 2. Add the new vectors
+        # 3. Store metadata (plus the raw vector) for each new internal ID
+        existing_ids_to_remove = []
+        for meta in list_data:
+            faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+            if faiss_internal_id is not None:
+                existing_ids_to_remove.append(faiss_internal_id)
+
+        if existing_ids_to_remove:
+            self._remove_faiss_ids(existing_ids_to_remove)
+
+        # Step 2: Add new vectors
+        start_idx = self._index.ntotal
+        self._index.add(embeddings)
+
+        # Step 3: Store metadata + vector for each new ID
+        for i, meta in enumerate(list_data):
+            fid = start_idx + i
+            # Store the raw vector so we can rebuild the index after removals
+            meta["__vector__"] = embeddings[i].tolist()
+            self._id_to_meta[fid] = meta
+
+        logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+        return [m["__id__"] for m in list_data]
+
+    async def query(self, query: str, top_k=5):
+        """
+        Search by a textual query; returns top_k results with their metadata + similarity distance.
+        """
+        embedding = await self.embedding_func([query])
+        # embedding is shape (1, dim)
+        embedding = np.array(embedding, dtype=np.float32)
+        faiss.normalize_L2(embedding)  # in-place normalization
+
+        logger.info(
+            f"Query: {query}, top_k: {top_k}, threshold: {self.cosine_better_than_threshold}"
+        )
+
+        # Perform the similarity search
+        distances, indices = self._index.search(embedding, top_k)
+
+        distances = distances[0]
+        indices = indices[0]
+
+        results = []
+        for dist, idx in zip(distances, indices):
+            if idx == -1:
+                # Faiss returns -1 when fewer than top_k neighbors exist
+                continue
+
+            # Cosine similarity threshold
+            if dist < self.cosine_better_than_threshold:
+                continue
+
+            meta = self._id_to_meta.get(idx, {})
+            results.append(
+                {
+                    **meta,
+                    "id": meta.get("__id__"),
+                    "distance": float(dist),
+                    "created_at": meta.get("__created_at__"),
+                }
+            )
+
+        return results
+
+    @property
+    def client_storage(self):
+        # Return whatever structure LightRAG might need for debugging
+        return {"data": list(self._id_to_meta.values())}
+
+    async def delete(self, ids: list[str]):
+        """
+        Delete vectors for the provided custom IDs.
+        """
+        logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+        to_remove = []
+        for cid in ids:
+            fid = self._find_faiss_id_by_custom_id(cid)
+            if fid is not None:
+                to_remove.append(fid)
+
+        if to_remove:
+            self._remove_faiss_ids(to_remove)
+        logger.info(
+            f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
+        )
+
+    async def delete_entity(self, entity_name: str):
+        """
+        Delete a single entity by computing its hashed ID
+        the same way the rest of the codebase does with `compute_mdhash_id`.
+        """
+        entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+        logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+        await self.delete([entity_id])
+
+    async def delete_entity_relation(self, entity_name: str):
+        """
+        Delete relations for a given entity by scanning metadata.
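+
+        Relation vectors carry ``src_id`` / ``tgt_id`` metadata fields, so a
+        linear scan over the in-memory metadata map is enough to find them.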
+        """
+        logger.debug(f"Searching relations for entity {entity_name}")
+        relations = []
+        for fid, meta in self._id_to_meta.items():
+            if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                relations.append(fid)
+
+        logger.debug(f"Found {len(relations)} relations for {entity_name}")
+        if relations:
+            self._remove_faiss_ids(relations)
+            logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
+
+    async def index_done_callback(self):
+        """
+        Called after indexing is done (saves the Faiss index + metadata).
+        """
+        self._save_faiss_index()
+        logger.info("Faiss index saved successfully.")
+
+    # --------------------------------------------------------------------------------
+    # Internal helper methods
+    # --------------------------------------------------------------------------------
+
+    def _find_faiss_id_by_custom_id(self, custom_id: str):
+        """
+        Return the Faiss internal ID for a given custom ID, or None if not found.
+        """
+        for fid, meta in self._id_to_meta.items():
+            if meta.get("__id__") == custom_id:
+                return fid
+        return None
+
+    def _remove_faiss_ids(self, fid_list):
+        """
+        Remove a list of internal Faiss IDs from the index.
+        Because IndexFlatIP doesn't support removals,
+        we rebuild the index excluding those vectors.
+        """
+        keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+
+        # Rebuild the index from the vectors we want to keep
+        vectors_to_keep = []
+        new_id_to_meta = {}
+        for new_fid, old_fid in enumerate(keep_fids):
+            vec_meta = self._id_to_meta[old_fid]
+            vectors_to_keep.append(vec_meta["__vector__"])  # stored as a list
+            new_id_to_meta[new_fid] = vec_meta
+
+        # Re-init the index
+        self._index = faiss.IndexFlatIP(self._dim)
+        if vectors_to_keep:
+            arr = np.array(vectors_to_keep, dtype=np.float32)
+            self._index.add(arr)
+
+        self._id_to_meta = new_id_to_meta
+
+    def _save_faiss_index(self):
+        """
+        Save the current Faiss index + metadata to disk so it can persist across runs.
+        """
+        faiss.write_index(self._index, self._faiss_index_file)
+
+        # Save the metadata dict as JSON. JSON requires string keys, so convert
+        # the internal int IDs; _id_to_meta is
+        # { int: { '__id__': doc_id, '__vector__': [float, ...], ... } }.
+        serializable_dict = {}
+        for fid, meta in self._id_to_meta.items():
+            serializable_dict[str(fid)] = meta
+
+        with open(self._meta_file, "w", encoding="utf-8") as f:
+            json.dump(serializable_dict, f)
+
+    def _load_faiss_index(self):
+        """
+        Load the Faiss index + metadata from disk if they exist,
+        and rebuild the in-memory structures so we can query.
+        """
+        if not os.path.exists(self._faiss_index_file):
+            logger.warning("No existing Faiss index file found. Starting fresh.")
+            return
+
+        try:
+            # Load the Faiss index
+            self._index = faiss.read_index(self._faiss_index_file)
+            # Load metadata
+            with open(self._meta_file, "r", encoding="utf-8") as f:
+                stored_dict = json.load(f)
+
+            # Convert string keys back to int
+            self._id_to_meta = {}
+            for fid_str, meta in stored_dict.items():
+                fid = int(fid_str)
+                self._id_to_meta[fid] = meta
+
+            logger.info(
+                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
+            )
+        except Exception as e:
+            logger.error(f"Failed to load Faiss index or metadata: {e}")
+            logger.warning("Starting with an empty Faiss index.")
+            self._index = faiss.IndexFlatIP(self._dim)
+            self._id_to_meta = {}
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 92fc954f..22db6994 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -60,6 +60,7 @@ STORAGES = {
     "PGGraphStorage": ".kg.postgres_impl",
     "GremlinStorage": ".kg.gremlin_impl",
     "PGDocStatusStorage": ".kg.postgres_impl",
+    "FaissVectorDBStorage": ".kg.faiss_impl",
 }
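
To exercise the new scan endpoints end to end, here is a minimal client-side sketch (not part of the diff). The base URL, port, and the `X-API-Key` header are assumptions borrowed from the web UI; adjust them to your server configuration.

```python
import time

import requests

BASE_URL = "http://localhost:9621"  # assumption: point this at your running server
HEADERS = {"X-API-Key": "your-api-key"}  # assumption: only needed if auth is enabled

# Kick off the background scan; the endpoint returns immediately.
resp = requests.post(f"{BASE_URL}/documents/scan", headers=HEADERS)
print(resp.json())  # {"status": "scanning_started"} or {"status": "already_scanning"}

# Poll the progress tracker until the background task clears is_scanning.
while True:
    progress = requests.get(f"{BASE_URL}/documents/scan-progress", headers=HEADERS).json()
    print(
        f"{progress['progress']:.0f}% "
        f"({progress['indexed_count']}/{progress['total_files']}) "
        f"{progress['current_file']}"
    )
    if not progress["is_scanning"]:
        break
    time.sleep(1)
```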