improve conditional checks for db instance

This commit is contained in:
ArnoChen
2025-02-19 04:53:15 +08:00
parent ae7a850d4e
commit f50604b2d3
4 changed files with 70 additions and 42 deletions

View File

@@ -3,7 +3,7 @@ import inspect
import json
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final
import numpy as np
import configparser
@@ -246,18 +246,17 @@ class ClientManager:
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: PostgreSQLDB
db: PostgreSQLDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(init=False)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(init=False)
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(init=False)
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
}
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if hasattr(self, "db") and self.db is not None:
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None