Refactor cosine-threshold handling to use the COSINE_THRESHOLD environment variable and global config settings for the Oracle, PostgreSQL, and TiDB vector storage backends

This commit is contained in:
yangdx
2025-01-29 23:47:57 +08:00
parent 46c9c7d95b
commit 06647438b2
3 changed files with 15 additions and 6 deletions

View File

@@ -1,3 +1,4 @@
import os
import asyncio import asyncio
# import html # import html
@@ -341,10 +342,14 @@ class OracleKVStorage(BaseKVStorage):
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db # should pass db object to self.db
db: OracleDB = None db: OracleDB = None
cosine_better_than_threshold: float = 0.2 cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self): def __post_init__(self):
pass # Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据""" """向向量数据库中插入数据"""

View File

@@ -301,12 +301,14 @@ class PGKVStorage(BaseKVStorage):
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2 cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
db: PostgreSQLDB = None db: PostgreSQLDB = None
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
self.cosine_better_than_threshold = self.global_config.get( # Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold "cosine_better_than_threshold", self.cosine_better_than_threshold
) )

View File

@@ -217,14 +217,16 @@ class TiDBKVStorage(BaseKVStorage):
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2 cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
self.cosine_better_than_threshold = self.global_config.get( # Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold "cosine_better_than_threshold", self.cosine_better_than_threshold
) )