diff --git a/lightrag/storage.py b/lightrag/storage.py
index 3bee911b..91ba7bcc 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -1,460 +1 @@
-import asyncio
-import html
-import os
-from tqdm.asyncio import tqdm as tqdm_async
-from dataclasses import dataclass
-from typing import Any, Union, cast, Dict
-import networkx as nx
-import numpy as np
-
-from nano_vectordb import NanoVectorDB
-import time
-
-from .utils import (
-    logger,
-    load_json,
-    write_json,
-    compute_mdhash_id,
-)
-
-from .base import (
-    BaseGraphStorage,
-    BaseKVStorage,
-    BaseVectorStorage,
-    DocStatus,
-    DocProcessingStatus,
-    DocStatusStorage,
-)
-
-
-@dataclass
-class JsonKVStorage(BaseKVStorage):
-    def __post_init__(self):
-        working_dir = self.global_config["working_dir"]
-        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._data = load_json(self._file_name) or {}
-        self._lock = asyncio.Lock()
-        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
-
-    async def all_keys(self) -> list[str]:
-        return list(self._data.keys())
-
-    async def index_done_callback(self):
-        write_json(self._data, self._file_name)
-
-    async def get_by_id(self, id):
-        return self._data.get(id, None)
-
-    async def get_by_ids(self, ids, fields=None):
-        if fields is None:
-            return [self._data.get(id, None) for id in ids]
-        return [
-            (
-                {k: v for k, v in self._data[id].items() if k in fields}
-                if self._data.get(id, None)
-                else None
-            )
-            for id in ids
-        ]
-
-    async def filter_keys(self, data: list[str]) -> set[str]:
-        return set([s for s in data if s not in self._data])
-
-    async def upsert(self, data: dict[str, dict]):
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
-        return left_data
-
-    async def drop(self):
-        self._data = {}
-
-    async def filter(self, filter_func):
-        """Filter key-value pairs based on a filter function
-
-        Args:
-            filter_func: The filter function, which takes a value as an argument and returns a boolean value
-
-        Returns:
-            Dict: Key-value pairs that meet the condition
-        """
-        result = {}
-        async with self._lock:
-            for key, value in self._data.items():
-                if filter_func(value):
-                    result[key] = value
-            return result
-
-    async def delete(self, ids: list[str]):
-        """Delete data with specified IDs
-
-        Args:
-            ids: List of IDs to delete
-        """
-        async with self._lock:
-            for id in ids:
-                if id in self._data:
-                    del self._data[id]
-            await self.index_done_callback()
-            logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")
-
-
-@dataclass
-class NanoVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = 0.2
-
-    def __post_init__(self):
-        self._client_file_name = os.path.join(
-            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
-        )
-        self._max_batch_size = self.global_config["embedding_batch_num"]
-        self._client = NanoVectorDB(
-            self.embedding_func.embedding_dim, storage_file=self._client_file_name
-        )
-        self.cosine_better_than_threshold = self.global_config.get(
-            "cosine_better_than_threshold", self.cosine_better_than_threshold
-        )
-
-    async def upsert(self, data: dict[str, dict]):
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not len(data):
-            logger.warning("You insert an empty data to vector DB")
-            return []
-
-        current_time = time.time()
-        list_data = [
-            {
-                "__id__": k,
-                "__created_at__": current_time,
-                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
-            }
-            for k, v in data.items()
-        ]
-        contents = [v["content"] for v in data.values()]
-        batches = [
-            contents[i : i + self._max_batch_size]
-            for i in range(0, len(contents), self._max_batch_size)
-        ]
-
-        async def wrapped_task(batch):
-            result = await self.embedding_func(batch)
-            pbar.update(1)
-            return result
-
-        embedding_tasks = [wrapped_task(batch) for batch in batches]
-        pbar = tqdm_async(
-            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
-        )
-        embeddings_list = await asyncio.gather(*embedding_tasks)
-
-        embeddings = np.concatenate(embeddings_list)
-        if len(embeddings) == len(list_data):
-            for i, d in enumerate(list_data):
-                d["__vector__"] = embeddings[i]
-            results = self._client.upsert(datas=list_data)
-            return results
-        else:
-            # sometimes the embedding is not returned correctly. just log it.
-            logger.error(
-                f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
-            )
-
-    async def query(self, query: str, top_k=5):
-        embedding = await self.embedding_func([query])
-        embedding = embedding[0]
-        results = self._client.query(
-            query=embedding,
-            top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold,
-        )
-        results = [
-            {
-                **dp,
-                "id": dp["__id__"],
-                "distance": dp["__metrics__"],
-                "created_at": dp.get("__created_at__"),
-            }
-            for dp in results
-        ]
-        return results
-
-    @property
-    def client_storage(self):
-        return getattr(self._client, "_NanoVectorDB__storage")
-
-    async def delete(self, ids: list[str]):
-        """Delete vectors with specified IDs
-
-        Args:
-            ids: List of vector IDs to be deleted
-        """
-        try:
-            self._client.delete(ids)
-            logger.info(
-                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
-            )
-        except Exception as e:
-            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
-
-    async def delete_entity(self, entity_name: str):
-        try:
-            entity_id = compute_mdhash_id(entity_name, prefix="ent-")
-            logger.debug(
-                f"Attempting to delete entity {entity_name} with ID {entity_id}"
-            )
-            # Check if the entity exists
-            if self._client.get([entity_id]):
-                await self.delete([entity_id])
-                logger.debug(f"Successfully deleted entity {entity_name}")
-            else:
-                logger.debug(f"Entity {entity_name} not found in storage")
-        except Exception as e:
-            logger.error(f"Error deleting entity {entity_name}: {e}")
-
-    async def delete_entity_relation(self, entity_name: str):
-        try:
-            relations = [
-                dp
-                for dp in self.client_storage["data"]
-                if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
-            ]
-            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
-            ids_to_delete = [relation["__id__"] for relation in relations]
-
-            if ids_to_delete:
-                await self.delete(ids_to_delete)
-                logger.debug(
-                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
-                )
-            else:
-                logger.debug(f"No relations found for entity {entity_name}")
-        except Exception as e:
-            logger.error(f"Error deleting relations for {entity_name}: {e}")
-
-    async def index_done_callback(self):
-        self._client.save()
-
-
-@dataclass
-class NetworkXStorage(BaseGraphStorage):
-    @staticmethod
-    def load_nx_graph(file_name) -> nx.Graph:
-        if os.path.exists(file_name):
-            return nx.read_graphml(file_name)
-        return None
-
-    @staticmethod
-    def write_nx_graph(graph: nx.Graph, file_name):
-        logger.info(
-            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
-        )
-        nx.write_graphml(graph, file_name)
-
-    @staticmethod
-    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
-        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
-        """
-        from graspologic.utils import largest_connected_component
-
-        graph = graph.copy()
-        graph = cast(nx.Graph, largest_connected_component(graph))
-        node_mapping = {
-            node: html.unescape(node.upper().strip()) for node in graph.nodes()
-        }  # type: ignore
-        graph = nx.relabel_nodes(graph, node_mapping)
-        return NetworkXStorage._stabilize_graph(graph)
-
-    @staticmethod
-    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
-        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
-        Ensure an undirected graph with the same relationships will always be read the same way.
-        """
-        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
-
-        sorted_nodes = graph.nodes(data=True)
-        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
-
-        fixed_graph.add_nodes_from(sorted_nodes)
-        edges = list(graph.edges(data=True))
-
-        if not graph.is_directed():
-
-            def _sort_source_target(edge):
-                source, target, edge_data = edge
-                if source > target:
-                    temp = source
-                    source = target
-                    target = temp
-                return source, target, edge_data
-
-            edges = [_sort_source_target(edge) for edge in edges]
-
-        def _get_edge_key(source: Any, target: Any) -> str:
-            return f"{source} -> {target}"
-
-        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
-
-        fixed_graph.add_edges_from(edges)
-        return fixed_graph
-
-    def __post_init__(self):
-        self._graphml_xml_file = os.path.join(
-            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
-        )
-        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-        if preloaded_graph is not None:
-            logger.info(
-                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-            )
-        self._graph = preloaded_graph or nx.Graph()
-        self._node_embed_algorithms = {
-            "node2vec": self._node2vec_embed,
-        }
-
-    async def index_done_callback(self):
-        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
-
-    async def has_node(self, node_id: str) -> bool:
-        return self._graph.has_node(node_id)
-
-    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        return self._graph.has_edge(source_node_id, target_node_id)
-
-    async def get_node(self, node_id: str) -> Union[dict, None]:
-        return self._graph.nodes.get(node_id)
-
-    async def node_degree(self, node_id: str) -> int:
-        return self._graph.degree(node_id)
-
-    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        return self._graph.degree(src_id) + self._graph.degree(tgt_id)
-
-    async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        return self._graph.edges.get((source_node_id, target_node_id))
-
-    async def get_node_edges(self, source_node_id: str):
-        if self._graph.has_node(source_node_id):
-            return list(self._graph.edges(source_node_id))
-        return None
-
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
-        self._graph.add_node(node_id, **node_data)
-
-    async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
-        self._graph.add_edge(source_node_id, target_node_id, **edge_data)
-
-    async def delete_node(self, node_id: str):
-        """
-        Delete a node from the graph based on the specified node_id.
-
-        :param node_id: The node_id to delete
-        """
-        if self._graph.has_node(node_id):
-            self._graph.remove_node(node_id)
-            logger.info(f"Node {node_id} deleted from the graph.")
-        else:
-            logger.warning(f"Node {node_id} not found in the graph for deletion.")
-
-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
-        if algorithm not in self._node_embed_algorithms:
-            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
-        return await self._node_embed_algorithms[algorithm]()
-
-    # @TODO: NOT USED
-    async def _node2vec_embed(self):
-        from graspologic import embed
-
-        embeddings, nodes = embed.node2vec_embed(
-            self._graph,
-            **self.global_config["node2vec_params"],
-        )
-
-        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
-        return embeddings, nodes_ids
-
-    def remove_nodes(self, nodes: list[str]):
-        """Delete multiple nodes
-
-        Args:
-            nodes: List of node IDs to be deleted
-        """
-        for node in nodes:
-            if self._graph.has_node(node):
-                self._graph.remove_node(node)
-
-    def remove_edges(self, edges: list[tuple[str, str]]):
-        """Delete multiple edges
-
-        Args:
-            edges: List of edges to be deleted, each edge is a (source, target) tuple
-        """
-        for source, target in edges:
-            if self._graph.has_edge(source, target):
-                self._graph.remove_edge(source, target)
-
-
-@dataclass
-class JsonDocStatusStorage(DocStatusStorage):
-    """JSON implementation of document status storage"""
-
-    def __post_init__(self):
-        working_dir = self.global_config["working_dir"]
-        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
-        self._data = load_json(self._file_name) or {}
-        logger.info(f"Loaded document status storage with {len(self._data)} records")
-
-    async def filter_keys(self, data: list[str]) -> set[str]:
-        """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(
-            [
-                k
-                for k in data
-                if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
-            ]
-        )
-
-    async def get_status_counts(self) -> Dict[str, int]:
-        """Get counts of documents in each status"""
-        counts = {status: 0 for status in DocStatus}
-        for doc in self._data.values():
-            counts[doc["status"]] += 1
-        return counts
-
-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all failed documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
-
-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all pending documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
-
-    async def index_done_callback(self):
-        """Save data to file after indexing"""
-        write_json(self._data, self._file_name)
-
-    async def upsert(self, data: dict[str, dict]):
-        """Update or insert document status
-
-        Args:
-            data: Dictionary of document IDs and their status data
-        """
-        self._data.update(data)
-        await self.index_done_callback()
-        return data
-
-    async def get_by_id(self, id: str):
-        return self._data.get(id)
-
-    async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
-        """Get document status by ID"""
-        return self._data.get(doc_id)
-
-    async def delete(self, doc_ids: list[str]):
-        """Delete document status by IDs"""
-        for doc_id in doc_ids:
-            self._data.pop(doc_id, None)
-        await self.index_done_callback()
+# This file is not needed anymore (TODO: remove)