refactor: improve database initialization by centralizing db instance injection

- Move db configs to separate methods
- Remove db field defaults in storage classes
- Add _initialize_database_if_needed method
- Inject db instances during initialization
- Clean up storage implementation code
This commit is contained in:
yangdx
2025-02-12 22:25:34 +08:00
parent cf61bed62c
commit 7b79427097
4 changed files with 205 additions and 200 deletions

View File

@@ -177,7 +177,8 @@ class PostgreSQLDB:
 @dataclass
 class PGKVStorage(BaseKVStorage):
-    db: PostgreSQLDB = None
+    # db instance must be injected before use
+    # db: PostgreSQLDB
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage):
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
+    # db instance must be injected before use
+    # db: PostgreSQLDB
     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
-    db: PostgreSQLDB = None
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage):
 class PGDocStatusStorage(DocStatusStorage):
     """PostgreSQL implementation of document status storage"""
-    db: PostgreSQLDB = None
-    def __post_init__(self):
-        pass
+    # db instance must be injected before use
+    db: PostgreSQLDB
     async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
@@ -577,19 +577,15 @@ class PGGraphQueryException(Exception):
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    db: PostgreSQLDB = None
+    # db instance must be injected before use
+    # db: PostgreSQLDB
     @staticmethod
     def load_nx_graph(file_name):
         print("no preloading of graph with AGE in production")
-    def __init__(self, namespace, global_config, embedding_func):
-        super().__init__(
-            namespace=namespace,
-            global_config=global_config,
-            embedding_func=embedding_func,
-        )
-        self.graph_name = os.environ["AGE_GRAPH_NAME"]
+    def __post_init__(self):
+        self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag")
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }