refactor: improve database initialization by centralizing db instance injection

- Move db configs to separate methods
- Remove db field defaults in storage classes
- Add _initialize_database_if_needed method
- Inject db instances during initialization
- Clean up storage implementation code
This commit is contained in:
yangdx
2025-02-12 22:25:34 +08:00
parent cf61bed62c
commit 7b79427097
4 changed files with 205 additions and 200 deletions

View File

@@ -172,8 +172,8 @@ class OracleDB:
@dataclass
class OracleKVStorage(BaseKVStorage):
# should pass db object to self.db
db: OracleDB = None
# db instance must be injected before use
# db: OracleDB
meta_fields = None
def __post_init__(self):
@@ -318,8 +318,8 @@ class OracleKVStorage(BaseKVStorage):
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db
db: OracleDB = None
# db instance must be injected before use
# db: OracleDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
@@ -361,8 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage):
@dataclass
class OracleGraphStorage(BaseGraphStorage):
# should pass db object to self.db
db: OracleDB = None
# db instance must be injected before use
# db: OracleDB
def __post_init__(self):
"""从graphml文件加载图"""

View File

@@ -177,7 +177,8 @@ class PostgreSQLDB:
@dataclass
class PGKVStorage(BaseKVStorage):
db: PostgreSQLDB = None
# db instance must be injected before use
# db: PostgreSQLDB
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage):
@dataclass
class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use
# db: PostgreSQLDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
db: PostgreSQLDB = None
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage):
class PGDocStatusStorage(DocStatusStorage):
"""PostgreSQL implementation of document status storage"""
db: PostgreSQLDB = None
def __post_init__(self):
pass
# db instance must be injected before use
db: PostgreSQLDB
async def filter_keys(self, data: set[str]) -> set[str]:
"""Return keys that don't exist in storage"""
@@ -577,19 +577,15 @@ class PGGraphQueryException(Exception):
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = None
# db instance must be injected before use
# db: PostgreSQLDB
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self.graph_name = os.environ["AGE_GRAPH_NAME"]
def __post_init__(self):
self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}

View File

@@ -101,8 +101,8 @@ class TiDB:
@dataclass
class TiDBKVStorage(BaseKVStorage):
# should pass db object to self.db
db: TiDB = None
# db instance must be injected before use
# db: TiDB
def __post_init__(self):
self._data = {}
@@ -210,8 +210,8 @@ class TiDBKVStorage(BaseKVStorage):
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db
db: TiDB = None
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
@@ -333,8 +333,8 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# should pass db object to self.db
db: TiDB = None
# db instance must be injected before use
# db: TiDB
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]