refactor: improve database initialization by centralizing db instance injection
- Move db configs to separate methods
- Remove db field defaults in storage classes
- Add _initialize_database_if_needed method
- Inject db instances during initialization
- Clean up storage implementation code
This commit is contained in:
@@ -172,8 +172,8 @@ class OracleDB:
|
||||
|
||||
@dataclass
|
||||
class OracleKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
meta_fields = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -318,8 +318,8 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -361,8 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
@dataclass
|
||||
class OracleGraphStorage(BaseGraphStorage):
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
# db instance must be injected before use
|
||||
# db: OracleDB
|
||||
|
||||
def __post_init__(self):
|
||||
"""从graphml文件加载图"""
|
||||
|
@@ -177,7 +177,8 @@ class PostgreSQLDB:
|
||||
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class PGVectorStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
db: PostgreSQLDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
@@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
"""PostgreSQL implementation of document status storage"""
|
||||
|
||||
db: PostgreSQLDB = None
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# db instance must be injected before use
|
||||
db: PostgreSQLDB
|
||||
|
||||
async def filter_keys(self, data: set[str]) -> set[str]:
|
||||
"""Return keys that don't exist in storage"""
|
||||
@@ -577,19 +577,15 @@ class PGGraphQueryException(Exception):
|
||||
|
||||
@dataclass
|
||||
class PGGraphStorage(BaseGraphStorage):
|
||||
db: PostgreSQLDB = None
|
||||
# db instance must be injected before use
|
||||
# db: PostgreSQLDB
|
||||
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self.graph_name = os.environ["AGE_GRAPH_NAME"]
|
||||
def __post_init__(self):
|
||||
self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
@@ -101,8 +101,8 @@ class TiDB:
|
||||
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
db: TiDB = None
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
@@ -210,8 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
# should pass db object to self.db
|
||||
db: TiDB = None
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -333,8 +333,8 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
# should pass db object to self.db
|
||||
db: TiDB = None
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
Reference in New Issue
Block a user