diff --git a/lightrag/storage/nano_vector_db.py b/lightrag/storage/nano_vector_db.py
new file mode 100644
index 00000000..1499076c
--- /dev/null
+++ b/lightrag/storage/nano_vector_db.py
@@ -0,0 +1,217 @@
"""
NanoVectorDB Storage Module
===========================

This module provides a vector-storage backend for LightRAG using NanoVectorDB, a lightweight vector database that persists its index to a single JSON file.

The `NanoVectorDBStorage` class extends the `BaseVectorStorage` class from the LightRAG library, providing methods to embed, upsert, query, and delete vectors.

Author: lightrag team
Created: 2024-01-25
License: MIT

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+ +Version: 1.0.0 + +Dependencies: + - NetworkX + - NumPy + - LightRAG + - graspologic + +Features: + - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) + - Query graph nodes and edges + - Calculate node and edge degrees + - Embed nodes using various algorithms (e.g., Node2Vec) + - Remove nodes and edges from the graph + +Usage: + from lightrag.storage.networkx_storage import NetworkXStorage + +""" + + +import asyncio +import html +import os +from tqdm.asyncio import tqdm as tqdm_async +from dataclasses import dataclass +from typing import Any, Union, cast, Dict +import numpy as np +import pipmaster as pm + +if not pm.is_installed("nano-vectordb"): + pm.install("nano-vectordb") + +from nano_vectordb import NanoVectorDB +import time + +from lightrag.utils import ( + logger, + load_json, + write_json, + compute_mdhash_id, +) + +from lightrag.base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, + DocStatus, + DocProcessingStatus, + DocStatusStorage, +) + + +@dataclass +class NanoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + self._client_file_name = os.path.join( + self.global_config["working_dir"], f"vdb_{self.namespace}.json" + ) + self._max_batch_size = self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + self.cosine_better_than_threshold = self.global_config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + async def upsert(self, data: dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + + current_time = time.time() + list_data = [ + { + "__id__": k, + "__created_at__": current_time, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v 
in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + + async def wrapped_task(batch): + result = await self.embedding_func(batch) + pbar.update(1) + return result + + embedding_tasks = [wrapped_task(batch) for batch in batches] + pbar = tqdm_async( + total=len(embedding_tasks), desc="Generating embeddings", unit="batch" + ) + embeddings_list = await asyncio.gather(*embedding_tasks) + + embeddings = np.concatenate(embeddings_list) + if len(embeddings) == len(list_data): + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + results = self._client.upsert(datas=list_data) + return results + else: + # sometimes the embedding is not returned correctly. just log it. + logger.error( + f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" + ) + + async def query(self, query: str, top_k=5): + embedding = await self.embedding_func([query]) + embedding = embedding[0] + results = self._client.query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + { + **dp, + "id": dp["__id__"], + "distance": dp["__metrics__"], + "created_at": dp.get("__created_at__"), + } + for dp in results + ] + return results + + @property + def client_storage(self): + return getattr(self._client, "_NanoVectorDB__storage") + + async def delete(self, ids: list[str]): + """Delete vectors with specified IDs + + Args: + ids: List of vector IDs to be deleted + """ + try: + self._client.delete(ids) + logger.info( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + + async def delete_entity(self, entity_name: str): + try: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + logger.debug( + f"Attempting to delete entity {entity_name} with ID {entity_id}" + ) + # Check if the entity exists + if 
self._client.get([entity_id]): + await self.delete([entity_id]) + logger.debug(f"Successfully deleted entity {entity_name}") + else: + logger.debug(f"Entity {entity_name} not found in storage") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + + async def delete_entity_relation(self, entity_name: str): + try: + relations = [ + dp + for dp in self.client_storage["data"] + if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name + ] + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") + ids_to_delete = [relation["__id__"] for relation in relations] + + if ids_to_delete: + await self.delete(ids_to_delete) + logger.debug( + f"Deleted {len(ids_to_delete)} relations for {entity_name}" + ) + else: + logger.debug(f"No relations found for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for {entity_name}: {e}") + + async def index_done_callback(self): + self._client.save()