refactor: make cosine similarity threshold a required config parameter

• Remove default threshold from env var
• Add validation for missing threshold
• Move default to lightrag.py config init
• Update all vector DB implementations
• Improve threshold validation consistency
This commit is contained in:
yangdx
2025-02-13 03:25:48 +08:00
parent 3308ecfa69
commit f01f57d0da
9 changed files with 59 additions and 30 deletions

View File

@@ -13,15 +13,15 @@ from lightrag.utils import logger
class ChromaVectorDBStorage(BaseVectorStorage): class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation.""" """ChromaDB vector storage implementation."""
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
try: try:
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
user_collection_settings = config.get("collection_settings", {}) user_collection_settings = config.get("collection_settings", {})
# Default HNSW index settings for ChromaDB # Default HNSW index settings for ChromaDB

View File

@@ -23,14 +23,15 @@ class FaissVectorDBStorage(BaseVectorStorage):
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
""" """
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Grab config values if available # Grab config values if available
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
# Where to save index file if you want persistent storage # Where to save index file if you want persistent storage
self._faiss_index_file = os.path.join( self._faiss_index_file = os.path.join(

View File

@@ -19,6 +19,8 @@ config.read("config.ini", "utf-8")
@dataclass @dataclass
class MilvusVectorDBStorge(BaseVectorStorage): class MilvusVectorDBStorge(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs client: MilvusClient, collection_name: str, **kwargs
@@ -30,6 +32,12 @@ class MilvusVectorDBStorge(BaseVectorStorage):
) )
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
self._client = MilvusClient( self._client = MilvusClient(
uri=os.environ.get( uri=os.environ.get(
"MILVUS_URI", "MILVUS_URI",
@@ -103,7 +111,7 @@ class MilvusVectorDBStorge(BaseVectorStorage):
data=embedding, data=embedding,
limit=top_k, limit=top_k,
output_fields=list(self.meta_fields), output_fields=list(self.meta_fields),
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}}, search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
) )
print(results) print(results)
return [ return [

View File

@@ -73,16 +73,17 @@ from lightrag.base import (
@dataclass @dataclass
class NanoVectorDBStorage(BaseVectorStorage): class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Initialize lock only for file operations # Initialize lock only for file operations
self._save_lock = asyncio.Lock() self._save_lock = asyncio.Lock()
# Use global config value if specified, otherwise use default # Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"

View File

@@ -320,14 +320,14 @@ class OracleKVStorage(BaseKVStorage):
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use # db instance must be injected before use
# db: OracleDB # db: OracleDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据""" """向向量数据库中插入数据"""

View File

@@ -299,15 +299,15 @@ class PGKVStorage(BaseKVStorage):
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use # db instance must be injected before use
# db: PostgreSQLDB # db: PostgreSQLDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
def _upsert_chunks(self, item: dict): def _upsert_chunks(self, item: dict):
try: try:

View File

@@ -50,6 +50,8 @@ def compute_mdhash_id_for_qdrant(
@dataclass @dataclass
class QdrantVectorDBStorage(BaseVectorStorage): class QdrantVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod @staticmethod
def create_collection_if_not_exist( def create_collection_if_not_exist(
client: QdrantClient, collection_name: str, **kwargs client: QdrantClient, collection_name: str, **kwargs
@@ -59,6 +61,12 @@ class QdrantVectorDBStorage(BaseVectorStorage):
client.create_collection(collection_name, **kwargs) client.create_collection(collection_name, **kwargs)
def __post_init__(self): def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
self._client = QdrantClient( self._client = QdrantClient(
url=os.environ.get( url=os.environ.get(
"QDRANT_URL", config.get("qdrant", "uri", fallback=None) "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
@@ -131,4 +139,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
with_payload=True, with_payload=True,
) )
logger.debug(f"query result: {results}") logger.debug(f"query result: {results}")
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] # 添加余弦相似度过滤
filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold]
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results]

View File

@@ -212,18 +212,18 @@ class TiDBKVStorage(BaseKVStorage):
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use # db instance must be injected before use
# db: TiDB # db: TiDB
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) cosine_better_than_threshold: float = None
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {}) config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get( cosine_threshold = config.get("cosine_better_than_threshold")
"cosine_better_than_threshold", self.cosine_better_than_threshold if cosine_threshold is None:
) raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]: async def query(self, query: str, top_k: int) -> list[dict]:
"""Search from tidb vector""" """Search from tidb vector"""

View File

@@ -420,6 +420,15 @@ class LightRAG:
# Check environment variables # Check environment variables
self.check_storage_env_vars(storage_name) self.check_storage_env_vars(storage_name)
# Ensure vector_db_storage_cls_kwargs has required fields
default_vector_db_kwargs = {
"cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
}
self.vector_db_storage_cls_kwargs = {
**default_vector_db_kwargs,
**self.vector_db_storage_cls_kwargs
}
# show config # show config
global_config = asdict(self) global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])