diff --git a/lightrag/base.py b/lightrag/base.py index 6a7374ac..5de3d423 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,9 +1,10 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from enum import StrEnum import os from dotenv import load_dotenv from dataclasses import dataclass, field -from enum import Enum from typing import ( Any, Literal, @@ -82,138 +83,130 @@ class QueryParam: @dataclass -class StorageNameSpace: +class StorageNameSpace(ABC): namespace: str global_config: dict[str, Any] + @abstractmethod async def index_done_callback(self) -> None: """Commit the storage operations after indexing""" - pass @dataclass -class BaseVectorStorage(StorageNameSpace): +class BaseVectorStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc + cosine_better_than_threshold: float = field(default=0.2) meta_fields: set[str] = field(default_factory=set) + @abstractmethod async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: - raise NotImplementedError + """Query the vector storage and retrieve top_k results.""" + @abstractmethod async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - """Use 'content' field from value for embedding, use key as id. - If embedding_func is None, use 'embedding' field from value - """ - raise NotImplementedError + """Insert or update vectors in the storage.""" + @abstractmethod async def delete_entity(self, entity_name: str) -> None: - """Delete a single entity by its name""" - raise NotImplementedError + """Delete a single entity by its name.""" + @abstractmethod async def delete_entity_relation(self, entity_name: str) -> None: - """Delete relations for a given entity by scanning metadata""" - raise NotImplementedError + """Delete relations for a given entity.""" @dataclass -class BaseKVStorage(StorageNameSpace): - embedding_func: EmbeddingFunc | None = None +class BaseKVStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc + @abstractmethod async def get_by_id(self, id: str) -> dict[str, Any] | None: - raise NotImplementedError + """Get value by id""" + @abstractmethod async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - raise NotImplementedError + """Get values by ids""" - async def filter_keys(self, data: set[str]) -> set[str]: + @abstractmethod + async def filter_keys(self, keys: set[str]) -> set[str]: """Return un-exist keys""" - raise NotImplementedError - async def upsert(self, data: dict[str, Any]) -> None: - raise NotImplementedError + @abstractmethod + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Upsert data""" + @abstractmethod async def drop(self) -> None: - raise NotImplementedError + """Drop the storage""" @dataclass -class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc | None = None - """Check if a node exists in the graph.""" +class BaseGraphStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc + @abstractmethod async def has_node(self, node_id: str) -> bool: - raise NotImplementedError - - """Check if an edge exists in the graph.""" + """Check if a node exists in the graph.""" + @abstractmethod async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - raise NotImplementedError - - """Get the degree of a node.""" + """Check if an edge exists in the graph.""" + @abstractmethod async def node_degree(self, node_id: str) -> int: - raise NotImplementedError - - """Get the degree of an edge.""" + """Get the degree of a node.""" + @abstractmethod async def edge_degree(self, src_id: str, tgt_id: str)
-> int: - raise NotImplementedError - - """Get a node by its id.""" + """Get the degree of an edge.""" + @abstractmethod async def get_node(self, node_id: str) -> dict[str, str] | None: - raise NotImplementedError - - """Get an edge by its source and target node ids.""" + """Get a node by its id.""" + @abstractmethod async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - raise NotImplementedError - - """Get all edges connected to a node.""" + """Get an edge by its source and target node ids.""" + @abstractmethod async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - raise NotImplementedError - - """Upsert a node into the graph.""" + """Get all edges connected to a node.""" + @abstractmethod async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - raise NotImplementedError - - """Upsert an edge into the graph.""" + """Upsert a node into the graph.""" + @abstractmethod async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: - raise NotImplementedError - - """Delete a node from the graph.""" + """Upsert an edge into the graph.""" + @abstractmethod async def delete_node(self, node_id: str) -> None: - raise NotImplementedError - - """Embed nodes using an algorithm.""" + """Delete a node from the graph.""" + @abstractmethod async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError("Node embedding is not used in lightrag.") - - """Get all labels in the graph.""" + """Embed nodes using an algorithm.""" + @abstractmethod async def get_all_labels(self) -> list[str]: - raise NotImplementedError - - """Get a knowledge graph of a node.""" + """Get all labels in the graph.""" + @abstractmethod async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: - raise NotImplementedError + """Retrieve a subgraph of the knowledge graph starting from a given node.""" -class DocStatus(str, Enum): - """Document processing status enum""" +class DocStatus(StrEnum): + """Document processing status""" PENDING = "pending" PROCESSING = "processing" @@ -245,19 +238,16 @@ class DocProcessingStatus: """Additional metadata""" -class DocStatusStorage(BaseKVStorage): +@dataclass +class DocStatusStorage(BaseKVStorage, ABC): """Base class for document status storage""" + @abstractmethod async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" - raise NotImplementedError + @abstractmethod async def get_docs_by_status( self, status: DocStatus ) -> dict[str, DocProcessingStatus]: """Get all documents with a specific status""" - raise NotImplementedError - - async def update_doc_status(self, data: dict[str, Any]) -> None: - """Updates the status of a document.
By default, it calls upsert.""" - await self.upsert(data) diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index a6857f22..243a110b 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -5,19 +5,11 @@ import os import sys from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union, final +import numpy as np import pipmaster as pm +from lightrag.types import KnowledgeGraph -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - - -import psycopg -from psycopg.rows import namedtuple_row -from psycopg_pool import AsyncConnectionPool, PoolTimeout from tenacity import ( retry, retry_if_exception_type, @@ -35,6 +27,23 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +if not pm.is_installed("psycopg-pool"): + pm.install("psycopg-pool") + pm.install("psycopg[binary,pool]") + +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + +try: + import psycopg + from psycopg.rows import namedtuple_row + from psycopg_pool import AsyncConnectionPool, PoolTimeout +except ImportError as e: + raise ImportError( + "`psycopg-pool, psycopg[binary,pool], asyncpg` libraries are not installed. Please install them via pip: `pip install psycopg-pool psycopg[binary,pool] asyncpg`." + ) from e + + class AGEQueryException(Exception): """Exception for the AGE queries.""" @@ -53,6 +62,7 @@ class AGEQueryException(Exception): return self.details +@final @dataclass class AGEStorage(BaseGraphStorage): @staticmethod @@ -98,9 +108,6 @@ class AGEStorage(BaseGraphStorage): if self._driver: await self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") - @staticmethod def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: """ @@ -396,7 +403,7 @@ class AGEStorage(BaseGraphStorage): ) return single_result["edge_exists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: entity_name_label = node_id.strip('"') query = """ MATCH (n:`{label}`) RETURN n @@ -454,17 +461,7 @@ class AGEStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given labels - - Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes - - Returns: - list: List of all relationships/edges found - """ + ) -> dict[str, str] | None: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -488,7 +485,7 @@ class AGEStorage(BaseGraphStorage): ) return result - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information @@ -526,7 +523,7 @@ wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((AGEQueryException,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the AGE database. @@ -562,8 +559,8 @@ retry=retry_if_exception_type((AGEQueryException,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. @@ -619,3 +616,23 @@ yield connection finally: await self._driver.putconn(connection) + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + + async def index_done_callback(self) -> None: + # AGE handles persistence automatically + pass diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index cb3b59f1..62a9b601 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,19 +1,29 @@ import asyncio from dataclasses import dataclass -from typing import Union +from typing import Any, final import numpy as np -from chromadb import HttpClient, PersistentClient -from chromadb.config import Settings + from lightrag.base import BaseVectorStorage from lightrag.utils import logger +import pipmaster as pm + +if not pm.is_installed("chromadb"): + pm.install("chromadb") + +try: + from chromadb import HttpClient, PersistentClient + from chromadb.config import Settings +except ImportError as e: + raise ImportError( + "`chromadb` library is not installed. Please install it via pip: `pip install chromadb`."
+ ) from e +@final @dataclass class ChromaVectorDBStorage(BaseVectorStorage): """ChromaDB vector storage implementation.""" - cosine_better_than_threshold: float = None - def __post_init__(self): try: config = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -102,7 +112,7 @@ logger.error(f"ChromaDB initialization failed: {str(e)}") raise - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: logger.warning("Empty data provided to vector DB") return [] @@ -151,7 +161,7 @@ logger.error(f"Error during ChromaDB upsert: {str(e)}") raise - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: try: embedding = await self.embedding_func([query]) @@ -183,6 +193,12 @@ logger.error(f"Error during ChromaDB query: {str(e)}") raise - async def index_done_callback(self): + async def index_done_callback(self) -> None: # ChromaDB handles persistence automatically pass + + async def delete_entity(self, entity_name: str) -> None: + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + raise NotImplementedError diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index ba94e39e..2b67e2fa 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -1,11 +1,13 @@ import os import time import asyncio -import faiss +from typing import Any, final + import json import numpy as np -from tqdm.asyncio import tqdm as tqdm_async + from dataclasses import dataclass +import pipmaster as pm from lightrag.utils import ( logger, @@ -15,7 +17,19 @@ from lightrag.base import ( BaseVectorStorage, ) +if not pm.is_installed("faiss"): + pm.install("faiss-cpu") +try: + import faiss + from tqdm.asyncio import tqdm as tqdm_async +except ImportError as e: + raise ImportError( + "`faiss` library is not installed. Please install it via pip: `pip install faiss-cpu`." + ) from e + + +@final @dataclass class FaissVectorDBStorage(BaseVectorStorage): """ @@ -23,8 +37,6 @@ Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. """ - cosine_better_than_threshold: float = None - def __post_init__(self): # Grab config values if available kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -57,7 +69,7 @@ # Attempt to load an existing index + metadata from disk self._load_faiss_index() - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -147,7 +159,7 @@ logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") return [m["__id__"] for m in list_data] - async def query(self, query: str, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ Search by a textual query; returns top_k results with their metadata + similarity distance.
""" @@ -210,11 +222,7 @@ class FaissVectorDBStorage(BaseVectorStorage): f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" ) - async def delete_entity(self, entity_name: str): - """ - Delete a single entity by computing its hashed ID - the same way your code does it with `compute_mdhash_id`. - """ + async def delete_entity(self, entity_name: str) -> None: entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}") await self.delete([entity_id]) @@ -234,12 +242,8 @@ class FaissVectorDBStorage(BaseVectorStorage): self._remove_faiss_ids(relations) logger.debug(f"Deleted {len(relations)} relations for {entity_name}") - async def index_done_callback(self): - """ - Called after indexing is done (save Faiss index + metadata). - """ + async def index_done_callback(self) -> None: self._save_faiss_index() - logger.info("Faiss index saved successfully.") # -------------------------------------------------------------------------------- # Internal helper methods diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index f38fd00a..40a9f007 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -3,11 +3,11 @@ import inspect import json import os from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, final + +import numpy as np + -from gremlin_python.driver import client, serializer -from gremlin_python.driver.aiohttp.transport import AiohttpTransport -from gremlin_python.driver.protocol import GremlinServerError from tenacity import ( retry, retry_if_exception_type, @@ -15,11 +15,22 @@ from tenacity import ( wait_exponential, ) +from lightrag.types import KnowledgeGraph from lightrag.utils import logger from ..base import BaseGraphStorage +try: + from gremlin_python.driver import client, serializer + from gremlin_python.driver.aiohttp.transport import AiohttpTransport + from gremlin_python.driver.protocol import GremlinServerError +except ImportError as e: + raise ImportError( + "`gremlin` library is not installed. Please install it via pip: `pip install gremlin`." 
+ ) from e + +@final @dataclass class GremlinStorage(BaseGraphStorage): @staticmethod @@ -76,8 +87,9 @@ class GremlinStorage(BaseGraphStorage): if self._driver: self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + # Gremlin handles persistence automatically + pass @staticmethod def _to_value_map(value: Any) -> str: @@ -190,7 +202,7 @@ class GremlinStorage(BaseGraphStorage): return result[0]["has_edge"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: entity_name = GremlinStorage._fix_name(node_id) query = f"""g .V().has('graph', {self.graph_name}) @@ -252,17 +264,7 @@ class GremlinStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Find all edges between nodes of two given names - - Args: - source_node_id (str): Name of the source nodes - target_node_id (str): Name of the target nodes - - Returns: - dict|None: Dict of found edge properties, or None if not found - """ + ) -> dict[str, str] | None: entity_name_source = GremlinStorage._fix_name(source_node_id) entity_name_target = GremlinStorage._fix_name(target_node_id) query = f"""g @@ -286,11 +288,7 @@ class GremlinStorage(BaseGraphStorage): ) return edge_properties - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: - """ - Retrieves all edges (relationships) for a particular node identified by its name. - :return: List of tuples containing edge sources and targets - """ + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_name = GremlinStorage._fix_name(source_node_id) query = f"""g .E() @@ -316,7 +314,7 @@ class GremlinStorage(BaseGraphStorage): wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((GremlinServerError,)), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the Gremlin graph. @@ -357,8 +355,8 @@ class GremlinStorage(BaseGraphStorage): retry=retry_if_exception_type((GremlinServerError,)), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their names. 
@@ -397,3 +395,19 @@ class GremlinStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index ed79a370..6c667891 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,56 +1,6 @@ -""" -JsonDocStatus Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - from dataclasses import dataclass import os -from typing import Any, Union +from typing import Any, Union, final from lightrag.base import ( DocProcessingStatus, @@ -64,6 +14,7 @@ from lightrag.utils import ( ) +@final @dataclass class JsonDocStatusStorage(DocStatusStorage): """JSON implementation of document status storage""" @@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage): self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Loaded document status storage with {len(self._data)} records") - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - return set(data) - set(self._data.keys()) + return set(keys) - set(self._data.keys()) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] @@ -88,7 +39,7 @@ class JsonDocStatusStorage(DocStatusStorage): async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" - counts = {status: 0 for status in DocStatus} + counts = {status.value: 0 for status in DocStatus} for doc in self._data.values(): counts[doc["status"]] += 1 return counts @@ -96,23 +47,17 @@ class JsonDocStatusStorage(DocStatusStorage): async def get_docs_by_status( self, status: DocStatus ) -> dict[str, DocProcessingStatus]: - """all documents with a specific status""" + """Get all documents with a specific status""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() - if v["status"] == status + if v["status"] == status.value } - async def index_done_callback(self): - """Save data to file after indexing""" + async def index_done_callback(self) -> None: write_json(self._data, self._file_name) - async def upsert(self, data: dict[str, Any]) -> None: - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: self._data.update(data) await self.index_done_callback() @@ -120,7 +65,9 @@ class JsonDocStatusStorage(DocStatusStorage): return self._data.get(id) async def delete(self, doc_ids: list[str]): - """Delete document status by IDs""" for doc_id in doc_ids: self._data.pop(doc_id, None) await self.index_done_callback() + + async def drop(self) -> None: + raise NotImplementedError diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 3ab5b966..658e1239 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any, final from lightrag.base import ( BaseKVStorage, @@ -13,6 +13,7 @@ from lightrag.utils import ( ) +@final @dataclass class JsonKVStorage(BaseKVStorage): def __post_init__(self): @@ -22,10 +23,10 @@ class JsonKVStorage(BaseKVStorage): self._lock = asyncio.Lock() logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - async def index_done_callback(self): + async def 
index_done_callback(self) -> None: write_json(self._data, self._file_name) - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: return self._data.get(id) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -38,8 +39,8 @@ class JsonKVStorage(BaseKVStorage): for id in ids ] - async def filter_keys(self, data: set[str]) -> set[str]: - return set(data) - set(self._data.keys()) + async def filter_keys(self, keys: set[str]) -> set[str]: + return set(keys) - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f4d9d47f..3e8f1ba5 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1,5 +1,6 @@ import asyncio import os +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -10,17 +11,21 @@ import configparser if not pm.is_installed("pymilvus"): pm.install("pymilvus") -from pymilvus import MilvusClient +try: + from pymilvus import MilvusClient +except ImportError as e: + raise ImportError( + "`pymilvus` library is not installed. Please install it via pip: `pip install pymilvus`." + ) from e config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class MilvusVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - @staticmethod def create_collection_if_not_exist( client: MilvusClient, collection_name: str, **kwargs @@ -71,7 +76,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): dimension=self.embedding_func.embedding_dim, ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -106,7 +111,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): results = self._client.upsert(collection_name=self.namespace, data=list_data) return results - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, @@ -123,3 +128,13 @@ class MilvusVectorDBStorage(BaseVectorStorage): {**dp["entity"], "id": dp["id"], "distance": dp["distance"]} for dp in results[0] ] + + async def index_done_callback(self) -> None: + # Milvus handles persistence automatically + pass + + async def delete_entity(self, entity_name: str) -> None: + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + raise NotImplementedError diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index f6326b76..4eb968cf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,22 +1,11 @@ import os from dataclasses import dataclass import numpy as np -import pipmaster as pm import configparser from tqdm.asyncio import tqdm as tqdm_async import asyncio -if not pm.is_installed("pymongo"): - pm.install("pymongo") - -if not pm.is_installed("motor"): - pm.install("motor") - -from typing import Any, List, Tuple, Union -from motor.motor_asyncio import AsyncIOMotorClient -from pymongo import MongoClient -from pymongo.operations import SearchIndexModel -from pymongo.errors import PyMongoError +from typing 
import Any, List, Union, final from ..base import ( BaseGraphStorage, @@ -29,12 +18,29 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +import pipmaster as pm +if not pm.is_installed("pymongo"): + pm.install("pymongo") + +if not pm.is_installed("motor"): + pm.install("motor") + +try: + from motor.motor_asyncio import AsyncIOMotorClient + from pymongo import MongoClient + from pymongo.operations import SearchIndexModel + from pymongo.errors import PyMongoError +except ImportError as e: + raise ImportError( + "`motor, pymongo` libraries are not installed. Please install them via pip: `pip install motor pymongo`." + ) from e config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): @@ -60,17 +66,17 @@ class MongoKVStorage(BaseKVStorage): # Ensure collection exists create_collection_if_not_exists(uri, database.name, self._collection_name) - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: return await self._data.find_one({"_id": id}) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: cursor = self._data.find({"_id": {"$in": ids}}) return await cursor.to_list() - async def filter_keys(self, data: set[str]) -> set[str]: - cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + async def filter_keys(self, keys: set[str]) -> set[str]: + cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) existing_ids = {str(x["_id"]) async for x in cursor} - return data - existing_ids + return keys - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -107,11 +113,16 @@ class MongoKVStorage(BaseKVStorage): else: return None + async def index_done_callback(self) -> None: + # Mongo handles persistence automatically + pass + async def drop(self) -> None: """Drop the collection""" await self._data.drop() +@final @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): @@ -191,7 +202,12 @@ class MongoDocStatusStorage(DocStatusStorage): for doc in result } + async def index_done_callback(self) -> None: + # Mongo handles persistence automatically + pass + +@final @dataclass class MongoGraphStorage(BaseGraphStorage): """ @@ -429,7 +445,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """ Return the full node document (including "edges"), or None if missing. """ @@ -437,11 +453,7 @@ async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """ - Return the first edge dict from source_node_id to target_node_id if it exists. - Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
- """ + ) -> dict[str, str] | None: pipeline = [ {"$match": {"_id": source_node_id}}, { @@ -467,9 +479,7 @@ class MongoGraphStorage(BaseGraphStorage): return e return None - async def get_node_edges( - self, source_node_id: str - ) -> Union[List[Tuple[str, str]], None]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Return a list of (source_id, target_id) for direct edges from source_node_id. Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. @@ -503,7 +513,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def upsert_node(self, node_id: str, node_data: dict): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Insert or update a node document. If new, create an empty edges array. """ @@ -513,8 +523,8 @@ class MongoGraphStorage(BaseGraphStorage): await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge from source_node_id -> target_node_id with optional 'relation'. If an edge with the same target exists, we remove it and re-insert with updated data. @@ -540,7 +550,7 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def delete_node(self, node_id: str): + async def delete_node(self, node_id: str) -> None: """ 1) Remove node's doc entirely. 2) Remove inbound edges from any doc that references node_id. @@ -557,7 +567,9 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: """ Placeholder for demonstration, raises NotImplementedError. 
""" @@ -759,11 +771,14 @@ class MongoGraphStorage(BaseGraphStorage): return result + async def index_done_callback(self) -> None: + # Mongo handles persistence automatically + pass + +@final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -828,7 +843,7 @@ class MongoVectorDBStorage(BaseVectorStorage): except PyMongoError as _: logger.debug("vector index already exist") - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") if not data: logger.warning("You are inserting an empty data set to vector DB") @@ -871,7 +886,7 @@ class MongoVectorDBStorage(BaseVectorStorage): return list_data - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" # Generate the embedding embedding = await self.embedding_func([query]) @@ -905,6 +920,16 @@ class MongoVectorDBStorage(BaseVectorStorage): for doc in results ] + async def index_done_callback(self) -> None: + # Mongo handles persistence automatically + pass + + async def delete_entity(self, entity_name: str) -> None: + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + raise NotImplementedError + def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): """Check if the collection exists. if not, create it.""" diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 96c5a8dd..16955d8a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -1,80 +1,35 @@ -""" -NanoVectorDB Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import asyncio import os +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np -import pipmaster as pm -if not pm.is_installed("nano-vectordb"): - pm.install("nano-vectordb") - -from nano_vectordb import NanoVectorDB import time from lightrag.utils import ( logger, compute_mdhash_id, ) - +import pipmaster as pm from lightrag.base import ( BaseVectorStorage, ) +if not pm.is_installed("nano-vectordb"): + pm.install("nano-vectordb") +try: + from nano_vectordb import NanoVectorDB +except ImportError as e: + raise ImportError( + "`nano-vectordb` library is not installed. Please install it via pip: `pip install nano-vectordb`." + ) from e + + +@final @dataclass class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - def __post_init__(self): # Initialize lock only for file operations self._save_lock = asyncio.Lock() @@ -95,7 +50,7 @@ class NanoVectorDBStorage(BaseVectorStorage): self.embedding_func.embedding_dim, storage_file=self._client_file_name ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -139,7 +94,7 @@ class NanoVectorDBStorage(BaseVectorStorage): f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" ) - async def query(self, query: str, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) embedding = embedding[0] results = self._client.query( @@ -176,7 +131,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - async def delete_entity(self, entity_name: str): + async def delete_entity(self, entity_name: str) -> None: try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( @@ -211,7 +166,6 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") - async def index_done_callback(self): - # Protect file write operation + async def index_done_callback(self) -> None: async with self._save_lock: self._client.save() diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 15525375..5ffbf2bc 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,20 +3,11 @@ import inspect import os import re from dataclasses import dataclass -from typing import Any, Union, Tuple, List, Dict -import pipmaster as pm +from typing import Any, List, Dict, final +import numpy as np import configparser -if not pm.is_installed("neo4j"): - pm.install("neo4j") -from neo4j import ( - AsyncGraphDatabase, - exceptions as neo4jExceptions, - AsyncDriver, - AsyncManagedTransaction, - GraphDatabase, -) from tenacity import ( retry, stop_after_attempt, @@ -27,12 +18,29 @@ from tenacity import ( from ..utils import logger from ..base import BaseGraphStorage from ..types 
import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +import pipmaster as pm +if not pm.is_installed("neo4j"): + pm.install("neo4j") + +try: + from neo4j import ( + AsyncGraphDatabase, + exceptions as neo4jExceptions, + AsyncDriver, + AsyncManagedTransaction, + GraphDatabase, + ) +except ImportError as e: + raise ImportError( + "`neo4j` library is not installed. Please install it via pip: `pip install neo4j`." + ) from e config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -140,8 +148,9 @@ class Neo4JStorage(BaseGraphStorage): if self._driver: await self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + # Neo4j handles persistence automatically + pass async def _label_exists(self, label: str) -> bool: """Check if a label exists in the Neo4j database.""" @@ -191,7 +200,7 @@ ) return single_result["edgeExists"] - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """Get node by its label identifier. Args: @@ -252,17 +261,7 @@ async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """Find edge between two nodes identified by their labels. - - Args: - source_node_id (str): Label of the source node - target_node_id (str): Label of the target node - - Returns: - dict: Edge properties if found, with at least {"weight": 0.0} - None: If error occurs - """ + ) -> dict[str, str] | None: try: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -321,7 +320,7 @@ # Return default edge properties on error return {"weight": 0.0, "source_id": None, "target_id": None} - async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: node_label = source_node_id.strip('"') """ @@ -364,7 +363,7 @@ ) ), ) - async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: """ Upsert a node in the Neo4j database. @@ -405,8 +404,8 @@ ), ) async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] - ): + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels.
@@ -603,7 +602,7 @@ class Neo4JStorage(BaseGraphStorage): await traverse(label, 0) return result - async def get_all_labels(self) -> List[str]: + async def get_all_labels(self) -> list[str]: """ Get all existing node labels in the database Returns: @@ -627,3 +626,11 @@ class Neo4JStorage(BaseGraphStorage): async for record in result: labels.append(record["label"]) return labels + + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: + raise NotImplementedError diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index bb84cf82..04bb3bd7 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,61 +1,12 @@ -""" -NetworkX Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import html import os from dataclasses import dataclass -from typing import Any, Union, cast -import networkx as nx +from typing import Any, cast, final + import numpy as np +from lightrag.types import KnowledgeGraph from lightrag.utils import ( logger, ) @@ -64,7 +15,15 @@ from lightrag.base import ( BaseGraphStorage, ) +try: + import networkx as nx +except ImportError as e: + raise ImportError( + "`networkx` library is not installed. Please install it via pip: `pip install networkx`." 
+ ) from e + +@final @dataclass class NetworkXStorage(BaseGraphStorage): @staticmethod @@ -142,7 +101,7 @@ class NetworkXStorage(BaseGraphStorage): "node2vec": self._node2vec_embed, } - async def index_done_callback(self): + async def index_done_callback(self) -> None: NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: @@ -151,7 +110,7 @@ class NetworkXStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: return self._graph.has_edge(source_node_id, target_node_id) - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: return self._graph.nodes.get(node_id) async def node_degree(self, node_id: str) -> int: @@ -162,35 +121,32 @@ class NetworkXStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> dict[str, str] | None: return self._graph.edges.get((source_node_id, target_node_id)) - async def get_node_edges(self, source_node_id: str): + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id)) return None - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: self._graph.add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def delete_node(self, node_id: str): - """ - Delete a node from the graph based on the specified node_id. 
- - :param node_id: The node_id to delete - """ + async def delete_node(self, node_id: str) -> None: if self._graph.has_node(node_id): self._graph.remove_node(node_id) logger.info(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -226,3 +182,11 @@ class NetworkXStorage(BaseGraphStorage): for source, target in edges: if self._graph.has_edge(source, target): self._graph.remove_edge(source, target) + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 65f1060c..c9d8d1b5 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -4,16 +4,11 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Union, final import numpy as np -import pipmaster as pm -if not pm.is_installed("oracledb"): - pm.install("oracledb") - - -import oracledb +from lightrag.types import KnowledgeGraph from ..base import ( BaseGraphStorage, @@ -23,6 +18,19 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger +import pipmaster as pm + +if not pm.is_installed("oracledb"): + pm.install("oracledb") + +try: + import oracledb + +except ImportError as e: + raise ImportError( + "`oracledb` library is not installed. Please install it via pip: `pip install oracledb`." 
+ ) from e + class OracleDB: def __init__(self, config, **kwargs): @@ -169,6 +177,7 @@ class OracleDB: raise +@final @dataclass class OracleKVStorage(BaseKVStorage): # db instance must be injected before use @@ -181,7 +190,7 @@ ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -232,7 +241,7 @@ res = [{k: v} for k, v in dict_res.items()] return res - async def filter_keys(self, keys: list[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), @@ -248,7 +257,7 @@ return set(keys) ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, Any]) -> None: + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { @@ -307,20 +316,17 @@ await self.db.execute(upsert_sql, _data) - async def index_done_callback(self): - if is_namespace( - self.namespace, - (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), - ): - logger.info("full doc and chunk data had been saved into oracle db!") + async def index_done_callback(self) -> None: + # Oracle handles persistence automatically + pass + + async def drop(self) -> None: + raise NotImplementedError +@final @dataclass class OracleVectorDBStorage(BaseVectorStorage): - # db instance must be injected before use - # db: OracleDB - cosine_better_than_threshold: float = None - def __post_init__(self): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") @@ -330,16 +336,8 @@ ) self.cosine_better_than_threshold = cosine_threshold - async def upsert(self, data: dict[str, dict]): - """向向量数据库中插入数据""" - pass - - async def index_done_callback(self): - pass - #################### query method ############### - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] # 转换精度 @@ -359,21 +357,29 @@ # print("vector search result:",results) return results + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + raise NotImplementedError + async def index_done_callback(self) -> None: + # Oracle handles persistence automatically + pass + + async def delete_entity(self, entity_name: str) -> None: + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + raise NotImplementedError + + +@final @dataclass class OracleGraphStorage(BaseGraphStorage): - # db instance must be injected before use - # db: OracleDB - def __post_init__(self): - """从graphml文件加载图""" self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ - async def upsert_node(self, node_id: str, node_data: dict[str, str]): - """插入或更新节点""" - #
print("go into upsert node method") + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: entity_name = node_id entity_type = node_data["entity_type"] description = node_data["description"] @@ -406,7 +412,7 @@ class OracleGraphStorage(BaseGraphStorage): async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + ) -> None: """插入或更新边""" # print("go into upsert edge method") source_name = source_node_id @@ -446,8 +452,9 @@ class OracleGraphStorage(BaseGraphStorage): await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - """为节点生成向量""" + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -464,11 +471,9 @@ class OracleGraphStorage(BaseGraphStorage): nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] return embeddings, nodes_ids - async def index_done_callback(self): - """写入graphhml图文件""" - logger.info( - "Node and edge data had been saved into oracle db already, so nothing to do here!" - ) + async def index_done_callback(self) -> None: + # Oracles handles persistence automatically + pass #################### query method ################# async def has_node(self, node_id: str) -> bool: @@ -486,7 +491,6 @@ class OracleGraphStorage(BaseGraphStorage): return False async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - """根据源和目标节点id检查边是否存在""" SQL = SQL_TEMPLATES["has_edge"] params = { "workspace": self.db.workspace, @@ -503,7 +507,6 @@ class OracleGraphStorage(BaseGraphStorage): return False async def node_degree(self, node_id: str) -> int: - """根据节点id获取节点的度""" SQL = SQL_TEMPLATES["node_degree"] params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) @@ -521,7 +524,7 @@ class OracleGraphStorage(BaseGraphStorage): # print("Edge degree",degree) return degree - async def get_node(self, node_id: str) -> Union[dict, None]: + async def get_node(self, node_id: str) -> dict[str, str] | None: """根据节点id获取节点数据""" SQL = SQL_TEMPLATES["get_node"] params = {"workspace": self.db.workspace, "node_id": node_id} @@ -537,8 +540,7 @@ class OracleGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: - """根据源和目标节点id获取边""" + ) -> dict[str, str] | None: SQL = SQL_TEMPLATES["get_edge"] params = { "workspace": self.db.workspace, @@ -553,8 +555,7 @@ class OracleGraphStorage(BaseGraphStorage): # print("Edge not exist!",self.db.workspace, source_node_id, target_node_id) return None - async def get_node_edges(self, source_node_id: str): - """根据节点id获取节点的所有边""" + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: if await self.has_node(source_node_id): SQL = SQL_TEMPLATES["get_node_edges"] params = {"workspace": self.db.workspace, "source_node_id": source_node_id} @@ -590,6 +591,17 @@ class OracleGraphStorage(BaseGraphStorage): if res: return res + async def delete_node(self, node_id: str) -> None: + raise NotImplementedError + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + N_T = { 
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 51b25385..5f845894 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -4,24 +4,19 @@ import json
 import os
 import time
 from dataclasses import dataclass
-from typing import Any, Dict, List, Set, Tuple, Union
+from typing import Any, Dict, List, Union, final

 import numpy as np
-import pipmaster as pm

-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
+from lightrag.types import KnowledgeGraph

 import sys
-
-import asyncpg
 from tenacity import (
     retry,
     retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
-from tqdm.asyncio import tqdm as tqdm_async

 from ..base import (
     BaseGraphStorage,
@@ -39,6 +34,20 @@ if sys.platform.startswith("win"):
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


+import pipmaster as pm
+
+if not pm.is_installed("asyncpg"):
+    pm.install("asyncpg")
+
+try:
+    import asyncpg
+    from tqdm.asyncio import tqdm as tqdm_async
+
+except ImportError as e:
+    raise ImportError(
+        "`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
+    ) from e
+

 class PostgreSQLDB:
     def __init__(self, config, **kwargs):
@@ -175,6 +184,7 @@ class PostgreSQLDB:
             pass


+@final
 @dataclass
 class PGKVStorage(BaseKVStorage):
     # db instance must be injected before use
@@ -185,7 +195,7 @@ class PGKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
@@ -240,7 +250,7 @@ class PGKVStorage(BaseKVStorage):
             params = {"workspace": self.db.workspace, "status": status}
             return await self.db.query(SQL, params, multirows=True)

-    async def filter_keys(self, keys: List[str]) -> Set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
@@ -261,7 +271,7 @@ class PGKVStorage(BaseKVStorage):
             print(params)

     ################ INSERT METHODS ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
         elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
@@ -287,20 +297,17 @@ class PGKVStorage(BaseKVStorage):

             await self.db.execute(upsert_sql, _data)

-    async def index_done_callback(self):
-        if is_namespace(
-            self.namespace,
-            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
-        ):
-            logger.info("full doc and chunk data had been saved into postgresql db!")
+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass
+
+    async def drop(self) -> None:
+        raise NotImplementedError


+@final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
-    cosine_better_than_threshold: float = None
-
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
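
# --- Sketch: the guarded-import pattern used across these backends ----------
# The import hunks in this diff converge on one shape: check/install the
# driver with pipmaster, then import inside try/except so a still-broken
# environment fails with an actionable message. `from e` chains the original
# ImportError so the root cause stays visible in the traceback. Distilled
# pattern, mirroring the asyncpg hunk above:
import pipmaster as pm

if not pm.is_installed("asyncpg"):
    pm.install("asyncpg")

try:
    import asyncpg  # noqa: F401
except ImportError as e:
    raise ImportError(
        "`asyncpg` library is not installed. Please install it via pip: `pip install asyncpg`."
    ) from e
# -----------------------------------------------------------------------------
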
@@ -352,7 +359,7 @@ class PGVectorStorage(BaseVectorStorage):
         }
         return upsert_sql, data

-    async def upsert(self, data: Dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -398,12 +405,8 @@ class PGVectorStorage(BaseVectorStorage):

         await self.db.execute(upsert_sql, data)

-    async def index_done_callback(self):
-        logger.info("vector data had been saved into postgresql db!")
-
     #################### query method ###############
-    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
-        """Query data from the vector database"""
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
         embedding_string = ",".join(map(str, embedding))
@@ -417,23 +420,31 @@
         results = await self.db.query(sql, params=params, multirows=True)
         return results

+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+
+@final
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
-
-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
-        keys = ",".join([f"'{_id}'" for _id in data])
-        sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})"
+        ids = ",".join([f"'{_id}'" for _id in keys])
+        sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({ids})"
         result = await self.db.query(sql, multirows=True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
-            return set(data)
+            return keys
         else:
             existed = set([element["id"] for element in result])
-            return set(data) - existed
+            return keys - existed

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
@@ -452,6 +463,9 @@ class PGDocStatusStorage(DocStatusStorage):
             updated_at=result[0]["updated_at"],
         )

+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        raise NotImplementedError
+
     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
         sql = """SELECT status as "status", COUNT(1) as "count"
@@ -470,7 +484,7 @@
     ) -> Dict[str, DocProcessingStatus]:
         """all documents with a specific status"""
         sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
-        params = {"workspace": self.db.workspace, "status": status}
+        params = {"workspace": self.db.workspace, "status": status.value}
         result = await self.db.query(sql, params, True)
         return {
             element["id"]: DocProcessingStatus(
@@ -485,11 +499,11 @@
             for element in result
         }

-    async def index_done_callback(self):
-        """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
-        logger.info("Doc status had been saved into postgresql db!")
+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """Update or insert document status

         Args:
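
# --- Sketch: the filter_keys contract ----------------------------------------
# filter_keys is now uniformly set[str] -> set[str]: given candidate ids, it
# returns the ones NOT yet in storage, so the indexing pipeline only touches
# new documents. (Note the rename above: the SQL id list is built into a
# separate local so the `keys` parameter is not shadowed.) A hypothetical
# in-memory version of the contract:
import asyncio


async def filter_keys_demo(stored: set[str], keys: set[str]) -> set[str]:
    """Return the subset of `keys` that is missing from storage."""
    return keys - stored


assert asyncio.run(filter_keys_demo({"doc-1"}, {"doc-1", "doc-2"})) == {"doc-2"}
# -----------------------------------------------------------------------------
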
@@ -520,31 +534,8 @@
         )
         return data

-    async def update_doc_status(self, data: dict[str, dict]) -> None:
-        """
-        Updates only the document status, chunk count, and updated timestamp.
-
-        This method ensures that only relevant fields are updated instead of overwriting
-        the entire document record. If `updated_at` is not provided, the database will
-        automatically use the current timestamp.
-        """
-        sql = """
-        UPDATE LIGHTRAG_DOC_STATUS
-        SET status = $3,
-            chunks_count = $4,
-            updated_at = CURRENT_TIMESTAMP
-        WHERE workspace = $1 AND id = $2
-        """
-        for k, v in data.items():
-            _data = {
-                "workspace": self.db.workspace,
-                "id": k,
-                "status": v["status"].value,  # Convert Enum to string
-                "chunks_count": v.get(
-                    "chunks_count", -1
-                ),  # Default to -1 if not provided
-            }
-            await self.db.execute(sql, _data)
+    async def drop(self) -> None:
+        raise NotImplementedError


 class PGGraphQueryException(Exception):
@@ -565,11 +556,9 @@ class PGGraphQueryException(Exception):
         return self.details


+@final
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
-
     @staticmethod
     def load_nx_graph(file_name):
         print("no preloading of graph with AGE in production")
@@ -580,8 +569,9 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }

-    async def index_done_callback(self):
-        print("KG successfully indexed.")
+    async def index_done_callback(self) -> None:
+        # PG handles persistence automatically
+        pass

     @staticmethod
     def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
@@ -811,7 +801,7 @@ class PGGraphStorage(BaseGraphStorage):
         )
         return single_result["edge_exists"]

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
         query = """SELECT * FROM cypher('%s', $$
                      MATCH (n:Entity {node_id: "%s"})
@@ -866,17 +856,7 @@ class PGGraphStorage(BaseGraphStorage):

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
-        """
-        Find all edges between nodes of two given labels
-
-        Args:
-            source_node_id (str): Label of the source nodes
-            target_node_id (str): Label of the target nodes
-
-        Returns:
-            list: List of all relationships/edges found
-        """
+    ) -> dict[str, str] | None:
         src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
         tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))

@@ -900,7 +880,7 @@ class PGGraphStorage(BaseGraphStorage):
         )
         return result

-    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """
         Retrieves all edges (relationships) for a particular node identified by its label.
         :return: List of dictionaries containing edge information
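
# --- Sketch: the retry policy around AGE graph writes ------------------------
# upsert_node/upsert_edge below stay wrapped in tenacity retries keyed on
# PGGraphQueryException. A self-contained sketch of that decorator stack;
# demo_upsert() and DemoGraphQueryException are hypothetical stand-ins:
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


class DemoGraphQueryException(Exception):
    pass


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((DemoGraphQueryException,)),
)
async def demo_upsert() -> None:
    # A DemoGraphQueryException raised here triggers an exponentially
    # backed-off retry (at most 3 attempts); other exceptions propagate.
    ...
# -----------------------------------------------------------------------------
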
@@ -948,14 +928,7 @@
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
-    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
-        """
-        Upsert a node in the AGE database.
-
-        Args:
-            node_id: The unique identifier for the node (used as label)
-            node_data: Dictionary of node properties
-        """
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
         properties = node_data

@@ -986,8 +959,8 @@ class PGGraphStorage(BaseGraphStorage):
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
-    ):
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.

@@ -1029,6 +1002,22 @@ class PGGraphStorage(BaseGraphStorage):
     async def _node2vec_embed(self):
         print("Implemented but never called.")

+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 NAMESPACE_TABLE_MAP = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index 3af76328..0610346f 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import numpy as np
@@ -7,16 +8,24 @@ import hashlib
 import uuid
 from ..utils import logger
 from ..base import BaseVectorStorage
-import pipmaster as pm
 import configparser

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
+import pipmaster as pm
+
 if not pm.is_installed("qdrant_client"):
     pm.install("qdrant_client")

-from qdrant_client import QdrantClient, models
+try:
+    from qdrant_client import QdrantClient, models

-config = configparser.ConfigParser()
-config.read("config.ini", "utf-8")
+except ImportError as e:
+    raise ImportError(
+        "`qdrant_client` library is not installed. Please install it via pip: `pip install qdrant-client`."
+    ) from e


 def compute_mdhash_id_for_qdrant(
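
# --- Sketch: why qdrant_impl derives UUIDs from md5 ids ----------------------
# Qdrant accepts only unsigned integers or UUIDs as point ids, while LightRAG
# ids are md5-based strings. An md5 digest is exactly 32 hex characters, so it
# can be reinterpreted as a UUID, which is what compute_mdhash_id_for_qdrant
# (below) does for its 'simple'/'hyphenated'/'urn' styles. Core conversion,
# with a hypothetical function name:
import hashlib
import uuid


def md5_to_uuid(content: str) -> str:
    digest = hashlib.md5(content.encode("utf-8")).hexdigest()  # 32 hex chars
    return str(uuid.UUID(hex=digest))  # regrouped as 8-4-4-4-12


print(md5_to_uuid("hello world"))
# -----------------------------------------------------------------------------
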
@@ -47,10 +56,9 @@ def compute_mdhash_id_for_qdrant(
     raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.")


+@final
 @dataclass
 class QdrantVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = None
-
     @staticmethod
     def create_collection_if_not_exist(
         client: QdrantClient, collection_name: str, **kwargs
@@ -85,7 +93,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
             ),
         )

-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
             return []
@@ -130,7 +138,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
             )
         return results

-    async def query(self, query, top_k=5):
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embedding = await self.embedding_func([query])
         results = self._client.search(
             collection_name=self.namespace,
@@ -143,3 +151,13 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         logger.debug(f"query result: {results}")

         return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
+
+    async def index_done_callback(self) -> None:
+        # Qdrant handles persistence automatically
+        pass
+
+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index ed8f46f9..2d5c94ce 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Union
+from typing import Any, final
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -19,6 +19,7 @@ config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")


+@final
 @dataclass
 class RedisKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -28,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None

@@ -39,16 +40,16 @@ class RedisKVStorage(BaseKVStorage):
             results = await pipe.execute()
             return [json.loads(result) if result else None for result in results]

-    async def filter_keys(self, data: set[str]) -> set[str]:
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
-        for key in data:
+        for key in keys:
             pipe.exists(f"{self.namespace}:{key}")
         results = await pipe.execute()

-        existing_ids = {data[i] for i, exists in enumerate(results) if exists}
-        return set(data) - existing_ids
+        existing_ids = {key for key, exists in zip(keys, results) if exists}
+        return keys - existing_ids

-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         pipe = self._redis.pipeline()
         for k, v in tqdm_async(data.items(), desc="Upserting"):
             pipe.set(f"{self.namespace}:{k}", json.dumps(v))
@@ -61,3 +62,7 @@ class RedisKVStorage(BaseKVStorage):
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
+
+    async def index_done_callback(self) -> None:
+        # Redis handles persistence automatically
+        pass
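
# --- Sketch: when index_done_callback is not a no-op -------------------------
# A recurring change in this diff: server-backed stores (Oracle, PG, Qdrant,
# Redis, TiDB) persist on every write, so index_done_callback becomes an
# explicit no-op. The hook exists for file-backed stores that buffer writes in
# memory. A hypothetical file-backed counterpart, showing the flush the
# server-backed stores get for free:
import json
from typing import Any


class DemoJsonKVStorage:
    def __init__(self, path: str) -> None:
        self._path = path
        self._data: dict[str, dict[str, Any]] = {}

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        self._data.update(data)  # buffered in memory only

    async def index_done_callback(self) -> None:
        # Flush once per indexing pass.
        with open(self._path, "w") as f:
            json.dump(self._data, f)
# -----------------------------------------------------------------------------
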
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index 00b8003d..6dbfb934 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -1,9 +1,18 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, Union, final

 import numpy as np
+
+from lightrag.types import KnowledgeGraph
+
+from tqdm import tqdm
+
+from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
+from ..namespace import NameSpace, is_namespace
+from ..utils import logger
+
 import pipmaster as pm

 if not pm.is_installed("pymysql"):
@@ -11,12 +20,13 @@ if not pm.is_installed("pymysql"):
 if not pm.is_installed("sqlalchemy"):
     pm.install("sqlalchemy")

-from sqlalchemy import create_engine, text
-from tqdm import tqdm
+try:
+    from sqlalchemy import create_engine, text

-from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
-from ..namespace import NameSpace, is_namespace
-from ..utils import logger
+except ImportError as e:
+    raise ImportError(
+        "`pymysql, sqlalchemy` libraries are not installed. Please install them via pip: `pip install pymysql sqlalchemy`."
+    ) from e


 class TiDB:
@@ -99,6 +109,7 @@ class TiDB:
             raise


+@final
 @dataclass
 class TiDBKVStorage(BaseKVStorage):
     # db instance must be injected before use
@@ -110,7 +121,7 @@ class TiDBKVStorage(BaseKVStorage):

     ################ QUERY METHODS ################

-    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+    async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Fetch doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"id": id}
@@ -125,8 +136,7 @@ class TiDBKVStorage(BaseKVStorage):
         )
         return await self.db.query(SQL, multirows=True)

-    async def filter_keys(self, keys: list[str]) -> set[str]:
-        """Filter out duplicate content"""
+    async def filter_keys(self, keys: set[str]) -> set[str]:
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=namespace_to_table_name(self.namespace),
             id_field=namespace_to_id(self.namespace),
@@ -147,7 +157,7 @@ class TiDBKVStorage(BaseKVStorage):
         return data

     ################ INSERT full_doc AND chunks ################
-    async def upsert(self, data: dict[str, Any]) -> None:
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
         if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -200,20 +210,17 @@ class TiDBKVStorage(BaseKVStorage):
                 await self.db.execute(merge_sql, data)
         return left_data

-    async def index_done_callback(self):
-        if is_namespace(
-            self.namespace,
-            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
-        ):
-            logger.info("full doc and chunk data had been saved into TiDB db!")
+    async def index_done_callback(self) -> None:
+        # TiDB handles persistence automatically
+        pass
+
+    async def drop(self) -> None:
+        raise NotImplementedError


+@final
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
-    # db instance must be injected before use
-    # db: TiDB
-    cosine_better_than_threshold: float = None
-
    def __post_init__(self):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -227,7 +234,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

-    async def query(self, query: str, top_k: int) -> list[dict]:
+    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Search from tidb vector"""
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
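
# --- Sketch: the annotation migration applied throughout ---------------------
# The signature rewrites in these hunks follow PEP 585/604: builtin generics
# and `X | None` replace typing.Union/Dict/List/Tuple, and the value types are
# pinned down (dict[str, str] for graph rows, dict[str, Any] for KV rows).
# Before/after on one representative, hypothetical signature (Python >= 3.10):
from typing import Any, Dict, Union


async def get_by_id_old(id: str) -> Union[Dict[str, Any], None]: ...


async def get_by_id_new(id: str) -> dict[str, Any] | None: ...
# -----------------------------------------------------------------------------
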
@@ -249,7 +256,7 @@
         return results

     ###### INSERT entities And relationships ######
-    async def upsert(self, data: dict[str, dict]):
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         # ignore, upsert in TiDBKVStorage already
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
@@ -332,7 +339,18 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)

+    async def delete_entity(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def delete_entity_relation(self, entity_name: str) -> None:
+        raise NotImplementedError
+
+    async def index_done_callback(self) -> None:
+        # TiDB handles persistence automatically
+        pass
+
+
+@final
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
     # db instance must be injected before use
@@ -342,7 +360,7 @@
         self._max_batch_size = self.global_config["embedding_batch_num"]

     #################### upsert method ################
-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         entity_name = node_id
         entity_type = node_data["entity_type"]
         description = node_data["description"]
@@ -373,7 +391,7 @@

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
-    ):
+    ) -> None:
         source_name = source_node_id
         target_name = target_node_id
         weight = edge_data["weight"]
@@ -409,7 +427,9 @@
         }
         await self.db.execute(merge_sql, data)

-    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+    async def embed_nodes(
+        self, algorithm: str
+    ) -> tuple[np.ndarray[Any, Any], list[str]]:
         if algorithm not in self._node_embed_algorithms:
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()
@@ -442,14 +462,14 @@
         degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
         return degree

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    async def get_node(self, node_id: str) -> dict[str, str] | None:
         sql = SQL_TEMPLATES["get_node"]
         param = {"name": node_id, "workspace": self.db.workspace}
         return await self.db.query(sql, param)

     async def get_edge(
         self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    ) -> dict[str, str] | None:
         sql = SQL_TEMPLATES["get_edge"]
         param = {
             "source_name": source_node_id,
@@ -458,9 +478,7 @@
         }
         return await self.db.query(sql, param)

-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         sql = SQL_TEMPLATES["get_node_edges"]
         param = {"source_name": source_node_id, "workspace": self.db.workspace}
         res = await self.db.query(sql, param, multirows=True)
@@ -470,6 +488,21 @@
         else:
             return []

+    async def index_done_callback(self) -> None:
+        # TiDB handles persistence automatically
+        pass
+
+    async def delete_node(self, node_id: str) -> None:
+        raise NotImplementedError
+
+    async def get_all_labels(self) -> list[str]:
+        raise NotImplementedError
+
+    async def get_knowledge_graph(
+        self, node_label: str, max_depth: int = 5
+    ) -> KnowledgeGraph:
+        raise NotImplementedError
+

 N_T = {
     NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 8fbf4d29..6220406b 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -674,7 +674,7 @@ class LightRAG:
                     "content": content,
                     "content_summary": self._get_content_summary(content),
                     "content_length": len(content),
-                    "status": DocStatus.PENDING,
+                    "status": DocStatus.PENDING.value,
                     "created_at": datetime.now().isoformat(),
                     "updated_at": datetime.now().isoformat(),
                 }
@@ -745,7 +745,7 @@ class LightRAG:
             await self.doc_status.upsert(
                 {
                     doc_status_id: {
-                        "status": DocStatus.PROCESSING,
+                        "status": DocStatus.PROCESSING.value,
                         "updated_at": datetime.now().isoformat(),
                         "content": status_doc.content,
                         "content_summary": status_doc.content_summary,
@@ -779,10 +779,10 @@
             ]
             try:
                 await asyncio.gather(*tasks)
-                await self.doc_status.update_doc_status(
+                await self.doc_status.upsert(
                     {
                         doc_status_id: {
-                            "status": DocStatus.PROCESSED,
+                            "status": DocStatus.PROCESSED.value,
                             "chunks_count": len(chunks),
                             "content": status_doc.content,
                             "content_summary": status_doc.content_summary,
@@ -796,10 +796,10 @@
             except Exception as e:
                 logger.error(f"Failed to process document {doc_id}: {str(e)}")
-                await self.doc_status.update_doc_status(
+                await self.doc_status.upsert(
                     {
                         doc_status_id: {
-                            "status": DocStatus.FAILED,
+                            "status": DocStatus.FAILED.value,
                             "error": str(e),
                             "content": status_doc.content,
                             "content_summary": status_doc.content_summary,
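
# --- Sketch: DocStatus as a StrEnum and the .value convention ----------------
# The lightrag.py hunks store DocStatus members via .value, so doc-status
# payloads hold plain JSON-serializable strings, and the now-removed
# update_doc_status path is folded into the generic upsert. Because DocStatus
# is a StrEnum (Python >= 3.11), members still compare equal to those strings,
# so reads work either way. Round trip under those assumptions; DemoDocStatus
# is a hypothetical stand-in:
import json
from enum import StrEnum


class DemoDocStatus(StrEnum):
    PENDING = "pending"
    PROCESSED = "processed"


payload = {"status": DemoDocStatus.PENDING.value}
stored = json.loads(json.dumps(payload))  # survives a JSON round trip
assert stored["status"] == "pending"
assert stored["status"] == DemoDocStatus.PENDING  # StrEnum compares as str
# -----------------------------------------------------------------------------
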