diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 92a0d9db..a6a33084 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,5 +1,5 @@ import os -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np import configparser from tqdm.asyncio import tqdm as tqdm_async @@ -27,7 +27,11 @@ if not pm.is_installed("motor"): pm.install("motor") try: - from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase + from motor.motor_asyncio import ( + AsyncIOMotorClient, + AsyncIOMotorDatabase, + AsyncIOMotorCollection, + ) from pymongo.operations import SearchIndexModel from pymongo.errors import PyMongoError except ImportError as e: @@ -79,19 +83,23 @@ class ClientManager: @final @dataclass class MongoKVStorage(BaseKVStorage): + db: AsyncIOMotorDatabase = field(init=False) + _data: AsyncIOMotorCollection = field(init=False) + def __post_init__(self): self._collection_name = self.namespace async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() self._data = await get_or_create_collection(self.db, self._collection_name) logger.debug(f"Use MongoDB as KV {self._collection_name}") async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None + self._data = None async def get_by_id(self, id: str) -> dict[str, Any] | None: return await self._data.find_one({"_id": id}) @@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage): @final @dataclass class MongoDocStatusStorage(DocStatusStorage): + db: AsyncIOMotorDatabase = field(init=False) + _data: AsyncIOMotorCollection = field(init=False) + def __post_init__(self): self._collection_name = self.namespace async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() self._data = await get_or_create_collection(self.db, self._collection_name) logger.debug(f"Use MongoDB as DocStatus {self._collection_name}") async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None + self._data = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return await self._data.find_one({"_id": id}) @@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage): @dataclass class MongoGraphStorage(BaseGraphStorage): """ - A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries. + A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries. """ + db: AsyncIOMotorDatabase = field(init=False) + collection: AsyncIOMotorCollection = field(init=False) + def __init__(self, namespace, global_config, embedding_func): super().__init__( namespace=namespace, @@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage): self._collection_name = self.namespace async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() self.collection = await get_or_create_collection( self.db, self._collection_name @@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage): logger.debug(f"Use MongoDB as KG {self._collection_name}") async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None + self.collection = None # # ------------------------------------------------------------------------- @@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage): @final @dataclass class MongoVectorDBStorage(BaseVectorStorage): + db: AsyncIOMotorDatabase = field(init=False) + _data: AsyncIOMotorCollection = field(init=False) + def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() self._data = await get_or_create_collection(self.db, self._collection_name) @@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage): logger.debug(f"Use MongoDB as VDB {self._collection_name}") async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None + self._data = None async def create_vector_index_if_not_exists(self): """Creates an Atlas Vector Search index.""" diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 8391acaa..70c5258c 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -3,7 +3,7 @@ import asyncio # import html import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Union, final import numpy as np import configparser @@ -242,8 +242,7 @@ class ClientManager: @final @dataclass class OracleKVStorage(BaseKVStorage): - # db instance must be injected before use - # db: OracleDB + db: OracleDB = field(init=False) meta_fields = None def __post_init__(self): @@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage): self._max_batch_size = self.global_config.get("embedding_batch_num", 10) async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage): @final @dataclass class OracleVectorDBStorage(BaseVectorStorage): + db: OracleDB = field(init=False) + def __post_init__(self): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") @@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage): self.cosine_better_than_threshold = cosine_threshold async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage): @final @dataclass class OracleGraphStorage(BaseGraphStorage): + db: OracleDB = field(init=False) + def __post_init__(self): self._max_batch_size = self.global_config.get("embedding_batch_num", 10) async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 2ec16716..654dc6dd 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -3,7 +3,7 @@ import inspect import json import os import time -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, List, Union, final import numpy as np import configparser @@ -246,18 +246,17 @@ class ClientManager: @final @dataclass class PGKVStorage(BaseKVStorage): - # db instance must be injected before use - # db: PostgreSQLDB + db: PostgreSQLDB = field(init=False) def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage): @final @dataclass class PGVectorStorage(BaseVectorStorage): + db: PostgreSQLDB = field(init=False) + def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] config = self.global_config.get("vector_db_storage_cls_kwargs", {}) @@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage): self.cosine_better_than_threshold = cosine_threshold async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage): @final @dataclass class PGDocStatusStorage(DocStatusStorage): + db: PostgreSQLDB = field(init=False) + async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -662,6 +665,8 @@ class PGGraphQueryException(Exception): @final @dataclass class PGGraphStorage(BaseGraphStorage): + db: PostgreSQLDB = field(init=False) + @staticmethod def load_nx_graph(file_name): print("no preloading of graph with AGE in production") @@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage): } async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index dc0dc422..e5c6995e 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -1,6 +1,6 @@ import asyncio import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Union, final import numpy as np @@ -166,19 +166,18 @@ class ClientManager: @final @dataclass class TiDBKVStorage(BaseKVStorage): - # db instance must be injected before use - # db: TiDB + db: TiDB = field(init=False) def __post_init__(self): self._data = {} self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -280,6 +279,8 @@ class TiDBKVStorage(BaseKVStorage): @final @dataclass class TiDBVectorDBStorage(BaseVectorStorage): + db: TiDB = field(init=False) + def __post_init__(self): self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" @@ -294,11 +295,11 @@ class TiDBVectorDBStorage(BaseVectorStorage): self.cosine_better_than_threshold = cosine_threshold async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -421,18 +422,17 @@ class TiDBVectorDBStorage(BaseVectorStorage): @final @dataclass class TiDBGraphStorage(BaseGraphStorage): - # db instance must be injected before use - # db: TiDB + db: TiDB = field(init=False) def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - if not hasattr(self, "db") or self.db is None: + if self.db is None: self.db = await ClientManager.get_client() async def finalize(self): - if hasattr(self, "db") and self.db is not None: + if self.db is not None: await ClientManager.release_client(self.db) self.db = None