diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 72a2627a..242c93ea 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -13,15 +13,15 @@ from lightrag.utils import logger class ChromaVectorDBStorage(BaseVectorStorage): """ChromaDB vector storage implementation.""" - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): try: - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold user_collection_settings = config.get("collection_settings", {}) # Default HNSW index settings for ChromaDB diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index fc6aa779..47111a47 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -23,14 +23,15 @@ class FaissVectorDBStorage(BaseVectorStorage): Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. 
""" - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): # Grab config values if available config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold # Where to save index file if you want persistent storage self._faiss_index_file = os.path.join( diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index ae0daac2..dd50c026 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -19,6 +19,8 @@ config.read("config.ini", "utf-8") @dataclass class MilvusVectorDBStorge(BaseVectorStorage): + cosine_better_than_threshold: float = None + @staticmethod def create_collection_if_not_exist( client: MilvusClient, collection_name: str, **kwargs @@ -30,6 +32,12 @@ class MilvusVectorDBStorge(BaseVectorStorage): ) def __post_init__(self): + storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) # kept distinct from the module-level config read from config.ini, so it is not shadowed here + cosine_threshold = storage_kwargs.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold + self._client = MilvusClient( uri=os.environ.get( "MILVUS_URI", @@ -103,7 +111,7 @@ class MilvusVectorDBStorge(BaseVectorStorage): data=embedding, limit=top_k, output_fields=list(self.meta_fields), - search_params={"metric_type": "COSINE", "params": {"radius": 0.2}}, + search_params={"metric_type": "COSINE", "params": 
{"radius": self.cosine_better_than_threshold}}, ) print(results) return [ diff --git 
a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 1cbd1b0b..5a61bf4f 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -73,16 +73,17 @@ from lightrag.base import ( @dataclass class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): # Initialize lock only for file operations self._save_lock = asyncio.Lock() # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index c2859829..5a1e0616 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -320,14 +320,14 @@ class OracleKVStorage(BaseKVStorage): class OracleVectorDBStorage(BaseVectorStorage): # db instance must be injected before use # db: OracleDB - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise 
ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold async def upsert(self, data: dict[str, dict]): """向向量数据库中插入数据""" diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 4b6f524f..dde88739 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -299,15 +299,15 @@ class PGKVStorage(BaseKVStorage): class PGVectorStorage(BaseVectorStorage): # db instance must be injected before use # db: PostgreSQLDB - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold def _upsert_chunks(self, item: dict): try: diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index bda23f8d..88dce27f 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -50,6 +50,8 @@ def compute_mdhash_id_for_qdrant( @dataclass class QdrantVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = None + @staticmethod def create_collection_if_not_exist( client: QdrantClient, collection_name: str, **kwargs @@ -59,6 +61,12 @@ class QdrantVectorDBStorage(BaseVectorStorage): client.create_collection(collection_name, **kwargs) def __post_init__(self): + storage_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) # kept distinct from the 
outer-scope config parser, which the QDRANT_URL fallback below still reads + cosine_threshold = 
storage_kwargs.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold + self._client = QdrantClient( url=os.environ.get( "QDRANT_URL", config.get("qdrant", "uri", fallback=None) @@ -131,4 +139,6 @@ class QdrantVectorDBStorage(BaseVectorStorage): with_payload=True, ) logger.debug(f"query result: {results}") - return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] + # Keep only hits whose score meets the configured cosine threshold + filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold] + return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results] diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index ba5a6240..248f2c85 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -212,18 +212,18 @@ class TiDBKVStorage(BaseKVStorage): class TiDBVectorDBStorage(BaseVectorStorage): # db instance must be injected before use # db: TiDB - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold async def query(self, query: str, top_k: int) 
-> list[dict]: """Search from tidb vector""" diff --git 
a/lightrag/lightrag.py b/lightrag/lightrag.py index e6217572..66508faf 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -420,6 +420,15 @@ class LightRAG: # Check environment variables self.check_storage_env_vars(storage_name) + # Ensure vector_db_storage_cls_kwargs has required fields + default_vector_db_kwargs = { + "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2")) + } + self.vector_db_storage_cls_kwargs = { + **default_vector_db_kwargs, + **self.vector_db_storage_cls_kwargs + } + # show config global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])