diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 2d1f631c..368e8618 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -1,3 +1,4 @@
+import os
 import asyncio
 
 # import html
@@ -341,10 +342,14 @@ class OracleKVStorage(BaseKVStorage):
 class OracleVectorDBStorage(BaseVectorStorage):
     # should pass db object to self.db
     db: OracleDB = None
-    cosine_better_than_threshold: float = 0.2
+    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
-        pass
+        # Threshold from vector_db_storage_cls_kwargs wins; else keep the current value
+        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        self.cosine_better_than_threshold = config.get(
+            "cosine_better_than_threshold", self.cosine_better_than_threshold
+        )
 
     async def upsert(self, data: dict[str, dict]):
         """向向量数据库中插入数据"""
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 57fe8d8d..b315abca 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -301,12 +301,14 @@ class PGKVStorage(BaseKVStorage):
 
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = 0.2
+    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
     db: PostgreSQLDB = None
 
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
-        self.cosine_better_than_threshold = self.global_config.get(
+        # Threshold from vector_db_storage_cls_kwargs wins; else keep the current value
+        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        self.cosine_better_than_threshold = config.get(
             "cosine_better_than_threshold", self.cosine_better_than_threshold
         )
 
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index d76c2c99..0579a57c 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -217,14 +217,16 @@ class TiDBKVStorage(BaseKVStorage):
 
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
-    cosine_better_than_threshold: float = 0.2
+    cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
 
     def __post_init__(self):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
-        self.cosine_better_than_threshold = self.global_config.get(
+        # Threshold from vector_db_storage_cls_kwargs wins; else keep the current value
+        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+        self.cosine_better_than_threshold = config.get(
             "cosine_better_than_threshold", self.cosine_better_than_threshold
         )
 