diff --git a/lightrag/base.py b/lightrag/base.py index fc4702d4..798e3176 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,6 +92,7 @@ class StorageNameSpace(ABC): @dataclass class BaseVectorStorage(StorageNameSpace, ABC): embedding_func: EmbeddingFunc + cosine_better_than_threshold: float meta_fields: set[str] = field(default_factory=set) @abstractmethod diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 37ab57d7..24f70de9 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -5,21 +5,11 @@ import os import sys from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np -import pipmaster as pm - -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - from lightrag.types import KnowledgeGraph -import psycopg -from psycopg.rows import namedtuple_row -from psycopg_pool import AsyncConnectionPool, PoolTimeout + from tenacity import ( retry, retry_if_exception_type, @@ -37,6 +27,16 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +try: + import psycopg + from psycopg.rows import namedtuple_row + from psycopg_pool import AsyncConnectionPool, PoolTimeout +except ImportError as e: + raise ImportError( + "psycopg-pool, psycopg[binary,pool], asyncpg library is not installed. Please install it to proceed." + ) from e + + class AGEQueryException(Exception): """Exception for the AGE queries.""" @@ -55,6 +55,7 @@ class AGEQueryException(Exception): return self.details +@final @dataclass class AGEStorage(BaseGraphStorage): @staticmethod @@ -100,9 +101,6 @@ class AGEStorage(BaseGraphStorage): if self._driver: await self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") - @staticmethod def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: """ @@ -627,3 +625,6 @@ class AGEStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: raise NotImplementedError + + async def index_done_callback(self) -> None: + pass diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 7e325abd..f2d2293f 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,19 +1,25 @@ import asyncio from dataclasses import dataclass -from typing import Any +from typing import Any, final import numpy as np -from chromadb import HttpClient, PersistentClient -from chromadb.config import Settings + from lightrag.base import BaseVectorStorage from lightrag.utils import logger +try: + from chromadb import HttpClient, PersistentClient + from chromadb.config import Settings +except ImportError as e: + raise ImportError( + "chromadb library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class ChromaVectorDBStorage(BaseVectorStorage): """ChromaDB vector storage implementation.""" - cosine_better_than_threshold: float = None - def __post_init__(self): try: config = self.global_config.get("vector_db_storage_cls_kwargs", {}) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 3027f3f0..e2c06afe 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -1,8 +1,8 @@ import os import time import asyncio -from typing import Any -import faiss +from typing import Any, final + import json import numpy as np from tqdm.asyncio import tqdm as tqdm_async @@ -16,7 +16,15 @@ from lightrag.base import ( BaseVectorStorage, ) +try: + import faiss +except ImportError as e: + raise ImportError( + "faiss library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class FaissVectorDBStorage(BaseVectorStorage): """ @@ -24,8 +32,6 @@ class FaissVectorDBStorage(BaseVectorStorage): Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. """ - cosine_better_than_threshold: float = None - def __post_init__(self): # Grab config values if available config = self.global_config.get("vector_db_storage_cls_kwargs", {}) diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 48bf77c8..4038be23 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -3,13 +3,11 @@ import inspect import json import os from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, final import numpy as np -from gremlin_python.driver import client, serializer -from gremlin_python.driver.aiohttp.transport import AiohttpTransport -from gremlin_python.driver.protocol import GremlinServerError + from tenacity import ( retry, retry_if_exception_type, @@ -22,7 +20,17 @@ from lightrag.utils import logger from ..base import BaseGraphStorage +try: + from gremlin_python.driver import client, serializer + from gremlin_python.driver.aiohttp.transport import AiohttpTransport + from gremlin_python.driver.protocol import GremlinServerError +except ImportError as e: + raise ImportError( + "gremlin library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class GremlinStorage(BaseGraphStorage): @staticmethod @@ -79,8 +87,8 @@ class GremlinStorage(BaseGraphStorage): if self._driver: self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + pass @staticmethod def _to_value_map(value: Any) -> str: diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index fad03acc..b96a744c 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -1,56 +1,6 @@ -""" -JsonDocStatus Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - from dataclasses import dataclass import os -from typing import Any, Union +from typing import Any, Union, final from lightrag.base import ( DocProcessingStatus, @@ -64,6 +14,7 @@ from lightrag.utils import ( ) +@final @dataclass class JsonDocStatusStorage(DocStatusStorage): """JSON implementation of document status storage""" @@ -74,9 +25,9 @@ class JsonDocStatusStorage(DocStatusStorage): self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Loaded document status storage with {len(self._data)} records") - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" - return set(data) - set(self._data.keys()) + return set(keys) - set(self._data.keys()) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: result: list[dict[str, Any]] = [] @@ -94,7 +45,6 @@ class JsonDocStatusStorage(DocStatusStorage): return counts async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() @@ -102,7 +52,6 @@ class JsonDocStatusStorage(DocStatusStorage): } async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() @@ -110,7 +59,6 @@ class JsonDocStatusStorage(DocStatusStorage): } async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processed documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() @@ -118,23 +66,16 @@ class JsonDocStatusStorage(DocStatusStorage): } async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() if v["status"] == DocStatus.PROCESSING } - async def index_done_callback(self): - """Save data to file after indexing""" + async def index_done_callback(self) -> None: write_json(self._data, self._file_name) async def upsert(self, data: dict[str, Any]) -> None: - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ self._data.update(data) await self.index_done_callback() @@ -142,7 +83,12 @@ class JsonDocStatusStorage(DocStatusStorage): return self._data.get(id) async def delete(self, doc_ids: list[str]): - """Delete document status by IDs""" for doc_id in doc_ids: self._data.pop(doc_id, None) await self.index_done_callback() + + async def drop(self) -> None: + raise NotImplementedError + + async def update_doc_status(self, data: dict[str, Any]) -> None: + raise NotImplementedError diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 7d51ae93..779c52a9 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -22,7 +22,7 @@ class JsonKVStorage(BaseKVStorage): self._lock = asyncio.Lock() logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - async def index_done_callback(self): + async def index_done_callback(self) -> None: write_json(self._data, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 703229c8..1288df07 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1,27 +1,29 @@ import asyncio import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np from lightrag.utils import logger from ..base import BaseVectorStorage -import pipmaster as pm + import configparser -if not pm.is_installed("pymilvus"): - pm.install("pymilvus") -from pymilvus import MilvusClient +try: + from pymilvus import MilvusClient +except ImportError: + raise ImportError( + "pymilvus library is not installed. Please install it to proceed." + ) config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class MilvusVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - @staticmethod def create_collection_if_not_exist( client: MilvusClient, collection_name: str, **kwargs diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 463e24d2..f44332bf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,22 +1,11 @@ import os from dataclasses import dataclass import numpy as np -import pipmaster as pm import configparser from tqdm.asyncio import tqdm as tqdm_async import asyncio -if not pm.is_installed("pymongo"): - pm.install("pymongo") - -if not pm.is_installed("motor"): - pm.install("motor") - -from typing import Any, List, Union -from motor.motor_asyncio import AsyncIOMotorClient -from pymongo import MongoClient -from pymongo.operations import SearchIndexModel -from pymongo.errors import PyMongoError +from typing import Any, List, Union, final from ..base import ( BaseGraphStorage, @@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace from ..utils import logger from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +try: + from motor.motor_asyncio import AsyncIOMotorClient + from pymongo import MongoClient + from pymongo.operations import SearchIndexModel + from pymongo.errors import PyMongoError +except ImportError as e: + raise ImportError( + "motor, pymongo library is not installed. Please install it to proceed." + ) from e + config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): @@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage): await self._data.drop() +@final @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): @@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage): """Get all procesed documents""" return await self.get_docs_by_status(DocStatus.PROCESSED) + async def index_done_callback(self) -> None: + # Implement the method here + pass + async def update_doc_status(self, data: dict[str, Any]) -> None: + raise NotImplementedError + + +@final @dataclass class MongoGraphStorage(BaseGraphStorage): """ @@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage): return result + async def index_done_callback(self) -> None: + pass + +@final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 8b931424..4ab98fe6 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -1,65 +1,10 @@ -""" -NanoVectorDB Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import asyncio import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np -import pipmaster as pm -if not pm.is_installed("nano-vectordb"): - pm.install("nano-vectordb") - -from nano_vectordb import NanoVectorDB import time from lightrag.utils import ( @@ -71,11 +16,17 @@ from lightrag.base import ( BaseVectorStorage, ) +try: + from nano_vectordb import NanoVectorDB +except ImportError as e: + raise ImportError( + "nano-vectordb library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - def __post_init__(self): # Initialize lock only for file operations self._save_lock = asyncio.Lock() diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index d8e8faa8..8d078af0 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -3,21 +3,11 @@ import inspect import os import re from dataclasses import dataclass -from typing import Any, List, Dict +from typing import Any, List, Dict, final import numpy as np -import pipmaster as pm import configparser -if not pm.is_installed("neo4j"): - pm.install("neo4j") -from neo4j import ( - AsyncGraphDatabase, - exceptions as neo4jExceptions, - AsyncDriver, - AsyncManagedTransaction, - GraphDatabase, -) from tenacity import ( retry, stop_after_attempt, @@ -29,11 +19,25 @@ from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +try: + from neo4j import ( + AsyncGraphDatabase, + exceptions as neo4jExceptions, + AsyncDriver, + AsyncManagedTransaction, + GraphDatabase, + ) +except ImportError as e: + raise ImportError( + "neo4j library is not installed. Please install it to proceed." + ) from e + config = configparser.ConfigParser() config.read("config.ini", "utf-8") +@final @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -141,8 +145,8 @@ class Neo4JStorage(BaseGraphStorage): if self._driver: await self._driver.close() - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + pass async def _label_exists(self, label: str) -> bool: """Check if a label exists in the Neo4j database.""" diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 109c5827..f98a8bbb 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -1,58 +1,8 @@ -""" -NetworkX Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import html import os from dataclasses import dataclass -from typing import Any, cast -import networkx as nx +from typing import Any, cast, final + import numpy as np @@ -65,7 +15,15 @@ from lightrag.base import ( BaseGraphStorage, ) +try: + import networkx as nx +except ImportError as e: + raise ImportError( + "networkx library is not installed. Please install it to proceed." + ) from e + +@final @dataclass class NetworkXStorage(BaseGraphStorage): @staticmethod diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 74268a67..aec4ada4 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -4,17 +4,11 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Union, final import numpy as np -import pipmaster as pm - -if not pm.is_installed("oracledb"): - pm.install("oracledb") - from lightrag.types import KnowledgeGraph -import oracledb from ..base import ( BaseGraphStorage, @@ -24,6 +18,14 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger +try: + import oracledb + +except ImportError as e: + raise ImportError( + "oracledb library is not installed. Please install it to proceed." + ) from e + class OracleDB: def __init__(self, config, **kwargs): @@ -170,6 +172,7 @@ class OracleDB: raise +@final @dataclass class OracleKVStorage(BaseKVStorage): # db instance must be injected before use @@ -319,12 +322,9 @@ class OracleKVStorage(BaseKVStorage): raise NotImplementedError +@final @dataclass class OracleVectorDBStorage(BaseVectorStorage): - # db instance must be injected before use - # db: OracleDB - cosine_better_than_threshold: float = None - def __post_init__(self): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") @@ -337,7 +337,7 @@ class OracleVectorDBStorage(BaseVectorStorage): async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pass - async def index_done_callback(self): + async def index_done_callback(self) -> None: pass #################### query method ############### @@ -370,13 +370,10 @@ class OracleVectorDBStorage(BaseVectorStorage): raise NotImplementedError +@final @dataclass class OracleGraphStorage(BaseGraphStorage): - # db instance must be injected before use - # db: OracleDB - def __post_init__(self): - """从graphml文件加载图""" self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ @@ -474,10 +471,7 @@ class OracleGraphStorage(BaseGraphStorage): return embeddings, nodes_ids async def index_done_callback(self) -> None: - """写入graphhml图文件""" - logger.info( - "Node and edge data had been saved into oracle db already, so nothing to do here!" - ) + pass #################### query method ################# async def has_node(self, node_id: str) -> bool: diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 77a42ad1..c63547ce 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,26 +4,19 @@ import json import os import time from dataclasses import dataclass -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, final import numpy as np -import pipmaster as pm from lightrag.types import KnowledgeGraph -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - import sys - -import asyncpg from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) -from tqdm.asyncio import tqdm as tqdm_async from ..base import ( BaseGraphStorage, @@ -41,6 +34,15 @@ if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +try: + import asyncpg + from tqdm.asyncio import tqdm as tqdm_async + +except ImportError as e: + raise ImportError( + "asyncpg, tqdm_async library is not installed. Please install it to proceed." + ) from e + class PostgreSQLDB: def __init__(self, config, **kwargs): @@ -177,6 +179,7 @@ class PostgreSQLDB: pass +@final @dataclass class PGKVStorage(BaseKVStorage): # db instance must be injected before use @@ -290,22 +293,15 @@ class PGKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) async def index_done_callback(self) -> None: - if is_namespace( - self.namespace, - (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), - ): - logger.info("full doc and chunk data had been saved into postgresql db!") + pass async def drop(self) -> None: raise NotImplementedError +@final @dataclass class PGVectorStorage(BaseVectorStorage): - # db instance must be injected before use - # db: PostgreSQLDB - cosine_better_than_threshold: float = None - def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] config = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -404,7 +400,7 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) async def index_done_callback(self) -> None: - logger.info("vector data had been saved into postgresql db!") + pass #################### query method ############### async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: @@ -430,22 +426,23 @@ class PGVectorStorage(BaseVectorStorage): raise NotImplementedError +@final @dataclass class PGDocStatusStorage(DocStatusStorage): # db instance must be injected before use # db: PostgreSQLDB - async def filter_keys(self, data: set[str]) -> set[str]: + async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that don't exist in storage""" - keys = ",".join([f"'{_id}'" for _id in data]) + keys = ",".join([f"'{_id}'" for _id in keys]) sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" result = await self.db.query(sql, multirows=True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. if result is None: - return set(data) + return set(keys) else: existed = set([element["id"] for element in result]) - return set(data) - existed + return set(keys) - existed async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" @@ -464,6 +461,9 @@ class PGDocStatusStorage(DocStatusStorage): updated_at=result[0]["updated_at"], ) + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + raise NotImplementedError + async def get_status_counts(self) -> Dict[str, int]: """Get counts of documents in each status""" sql = """SELECT status as "status", COUNT(1) as "count" @@ -513,9 +513,8 @@ class PGDocStatusStorage(DocStatusStorage): """Get all procesed documents""" return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self): - """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" - logger.info("Doc status had been saved into postgresql db!") + async def index_done_callback(self) -> None: + pass async def upsert(self, data: dict[str, dict]): """Update or insert document status @@ -574,6 +573,9 @@ class PGDocStatusStorage(DocStatusStorage): } await self.db.execute(sql, _data) + async def drop(self) -> None: + raise NotImplementedError + class PGGraphQueryException(Exception): """Exception for the AGE queries.""" @@ -593,11 +595,9 @@ class PGGraphQueryException(Exception): return self.details +@final @dataclass class PGGraphStorage(BaseGraphStorage): - # db instance must be injected before use - # db: PostgreSQLDB - @staticmethod def load_nx_graph(file_name): print("no preloading of graph with AGE in production") @@ -608,8 +608,8 @@ class PGGraphStorage(BaseGraphStorage): "node2vec": self._node2vec_embed, } - async def index_done_callback(self): - print("KG successfully indexed.") + async def index_done_callback(self) -> None: + pass @staticmethod def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index eb9582e6..1d4a0ca1 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -8,17 +8,20 @@ import hashlib import uuid from ..utils import logger from ..base import BaseVectorStorage -import pipmaster as pm import configparser -if not pm.is_installed("qdrant_client"): - pm.install("qdrant_client") - -from qdrant_client import QdrantClient, models config = configparser.ConfigParser() config.read("config.ini", "utf-8") +try: + from qdrant_client import QdrantClient, models + +except ImportError as e: + raise ImportError( + "qdrant_client library is not installed. Please install it to proceed." + ) from e + def compute_mdhash_id_for_qdrant( content: str, prefix: str = "", style: str = "simple" @@ -48,10 +51,9 @@ def compute_mdhash_id_for_qdrant( raise ValueError("Invalid style. Choose from 'simple', 'hyphenated', or 'urn'.") +@final @dataclass class QdrantVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = None - @staticmethod def create_collection_if_not_exist( client: QdrantClient, collection_name: str, **kwargs diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 27850d81..69a6da2a 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -1,24 +1,26 @@ import asyncio import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Union, final import numpy as np -import pipmaster as pm - -if not pm.is_installed("pymysql"): - pm.install("pymysql") -if not pm.is_installed("sqlalchemy"): - pm.install("sqlalchemy") from lightrag.types import KnowledgeGraph -from sqlalchemy import create_engine, text + from tqdm import tqdm from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..namespace import NameSpace, is_namespace from ..utils import logger +try: + from sqlalchemy import create_engine, text + +except ImportError as e: + raise ImportError( + "pymysql, sqlalchemy library is not installed. Please install it to proceed." + ) from e + class TiDB: def __init__(self, config, **kwargs): @@ -100,6 +102,7 @@ class TiDB: raise +@final @dataclass class TiDBKVStorage(BaseKVStorage): # db instance must be injected before use @@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage): await self.db.execute(merge_sql, data) return left_data - async def index_done_callback(self): - if is_namespace( - self.namespace, - (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), - ): - logger.info("full doc and chunk data had been saved into TiDB db!") + async def index_done_callback(self) -> None: + pass async def drop(self) -> None: raise NotImplementedError +@final @dataclass class TiDBVectorDBStorage(BaseVectorStorage): - # db instance must be injected before use - # db: TiDB - cosine_better_than_threshold: float = None - def __post_init__(self): self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" @@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + async def index_done_callback(self) -> None: + raise NotImplementedError + +@final @dataclass class TiDBGraphStorage(BaseGraphStorage): # db instance must be injected before use @@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage): else: return [] + async def index_done_callback(self) -> None: + pass + async def delete_node(self, node_id: str) -> None: raise NotImplementedError