change init=False to default=None for db fields to prevent no attribute error

This commit is contained in:
ArnoChen
2025-02-19 04:55:59 +08:00
parent f3b030e1a3
commit bfc548edf2
4 changed files with 18 additions and 18 deletions

View File

@@ -83,8 +83,8 @@ class ClientManager:
@final @final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
db: AsyncIOMotorDatabase = field(init=False) db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(init=False) _data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
@@ -156,8 +156,8 @@ class MongoKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class MongoDocStatusStorage(DocStatusStorage): class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(init=False) db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(init=False) _data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
@@ -236,8 +236,8 @@ class MongoGraphStorage(BaseGraphStorage):
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries. A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
""" """
db: AsyncIOMotorDatabase = field(init=False) db: AsyncIOMotorDatabase = field(default=None)
collection: AsyncIOMotorCollection = field(init=False) collection: AsyncIOMotorCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func):
super().__init__( super().__init__(
@@ -798,8 +798,8 @@ class MongoGraphStorage(BaseGraphStorage):
@final @final
@dataclass @dataclass
class MongoVectorDBStorage(BaseVectorStorage): class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(init=False) db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(init=False) _data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self): def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})

View File

@@ -242,7 +242,7 @@ class ClientManager:
@final @final
@dataclass @dataclass
class OracleKVStorage(BaseKVStorage): class OracleKVStorage(BaseKVStorage):
db: OracleDB = field(init=False) db: OracleDB = field(default=None)
meta_fields = None meta_fields = None
def __post_init__(self): def __post_init__(self):
@@ -394,7 +394,7 @@ class OracleKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(init=False) db: OracleDB = field(default=None)
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -450,7 +450,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
@final @final
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
db: OracleDB = field(init=False) db: OracleDB = field(default=None)
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10) self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

View File

@@ -246,7 +246,7 @@ class ClientManager:
@final @final
@dataclass @dataclass
class PGKVStorage(BaseKVStorage): class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = field(init=False) db: PostgreSQLDB = field(default=None)
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -378,7 +378,7 @@ class PGKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(init=False) db: PostgreSQLDB = field(default=None)
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -515,7 +515,7 @@ class PGVectorStorage(BaseVectorStorage):
@final @final
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(init=False) db: PostgreSQLDB = field(default=None)
async def initialize(self): async def initialize(self):
if self.db is None: if self.db is None:
@@ -665,7 +665,7 @@ class PGGraphQueryException(Exception):
@final @final
@dataclass @dataclass
class PGGraphStorage(BaseGraphStorage): class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(init=False) db: PostgreSQLDB = field(default=None)
@staticmethod @staticmethod
def load_nx_graph(file_name): def load_nx_graph(file_name):

View File

@@ -166,7 +166,7 @@ class ClientManager:
@final @final
@dataclass @dataclass
class TiDBKVStorage(BaseKVStorage): class TiDBKVStorage(BaseKVStorage):
db: TiDB = field(init=False) db: TiDB = field(default=None)
def __post_init__(self): def __post_init__(self):
self._data = {} self._data = {}
@@ -279,7 +279,7 @@ class TiDBKVStorage(BaseKVStorage):
@final @final
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB = field(init=False) db: TiDB = field(default=None)
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
@@ -422,7 +422,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@final @final
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
db: TiDB = field(init=False) db: TiDB = field(default=None)
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]