refactor: make cosine similarity threshold a required config parameter
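• Remove the env-var-derived default threshold from the vector storage classes
• Raise ValueError when cosine_better_than_threshold is missing from vector_db_storage_cls_kwargs
• Move the COSINE_THRESHOLD default (0.2) into the LightRAG config initialization in lightrag.py
• Update all vector DB implementations to read the threshold from config
• Make threshold validation consistent across implementations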
This commit is contained in:
@@ -13,15 +13,15 @@ from lightrag.utils import logger
|
|||||||
class ChromaVectorDBStorage(BaseVectorStorage):
|
class ChromaVectorDBStorage(BaseVectorStorage):
|
||||||
"""ChromaDB vector storage implementation."""
|
"""ChromaDB vector storage implementation."""
|
||||||
|
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
try:
|
try:
|
||||||
# Use global config value if specified, otherwise use default
|
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
self.cosine_better_than_threshold = config.get(
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
if cosine_threshold is None:
|
||||||
)
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
user_collection_settings = config.get("collection_settings", {})
|
user_collection_settings = config.get("collection_settings", {})
|
||||||
# Default HNSW index settings for ChromaDB
|
# Default HNSW index settings for ChromaDB
|
||||||
|
@@ -23,14 +23,15 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Grab config values if available
|
# Grab config values if available
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
self.cosine_better_than_threshold = config.get(
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
if cosine_threshold is None:
|
||||||
)
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
# Where to save index file if you want persistent storage
|
# Where to save index file if you want persistent storage
|
||||||
self._faiss_index_file = os.path.join(
|
self._faiss_index_file = os.path.join(
|
||||||
|
@@ -19,6 +19,8 @@ config.read("config.ini", "utf-8")
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MilvusVectorDBStorge(BaseVectorStorage):
|
class MilvusVectorDBStorge(BaseVectorStorage):
|
||||||
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_collection_if_not_exist(
|
def create_collection_if_not_exist(
|
||||||
client: MilvusClient, collection_name: str, **kwargs
|
client: MilvusClient, collection_name: str, **kwargs
|
||||||
@@ -30,6 +32,12 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
|
if cosine_threshold is None:
|
||||||
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
self._client = MilvusClient(
|
self._client = MilvusClient(
|
||||||
uri=os.environ.get(
|
uri=os.environ.get(
|
||||||
"MILVUS_URI",
|
"MILVUS_URI",
|
||||||
@@ -103,7 +111,7 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
|||||||
data=embedding,
|
data=embedding,
|
||||||
limit=top_k,
|
limit=top_k,
|
||||||
output_fields=list(self.meta_fields),
|
output_fields=list(self.meta_fields),
|
||||||
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
|
search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
|
||||||
)
|
)
|
||||||
print(results)
|
print(results)
|
||||||
return [
|
return [
|
||||||
|
@@ -73,16 +73,17 @@ from lightrag.base import (
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NanoVectorDBStorage(BaseVectorStorage):
|
class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Initialize lock only for file operations
|
# Initialize lock only for file operations
|
||||||
self._save_lock = asyncio.Lock()
|
self._save_lock = asyncio.Lock()
|
||||||
# Use global config value if specified, otherwise use default
|
# Use global config value if specified, otherwise use default
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
self.cosine_better_than_threshold = config.get(
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
if cosine_threshold is None:
|
||||||
)
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||||
|
@@ -320,14 +320,14 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
# db: OracleDB
|
# db: OracleDB
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Use global config value if specified, otherwise use default
|
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
self.cosine_better_than_threshold = config.get(
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
if cosine_threshold is None:
|
||||||
)
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict]):
|
async def upsert(self, data: dict[str, dict]):
|
||||||
"""向向量数据库中插入数据"""
|
"""向向量数据库中插入数据"""
|
||||||
|
@@ -299,15 +299,15 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
class PGVectorStorage(BaseVectorStorage):
|
class PGVectorStorage(BaseVectorStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
# db: PostgreSQLDB
|
# db: PostgreSQLDB
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
# Use global config value if specified, otherwise use default
|
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
self.cosine_better_than_threshold = config.get(
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
if cosine_threshold is None:
|
||||||
)
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
def _upsert_chunks(self, item: dict):
|
def _upsert_chunks(self, item: dict):
|
||||||
try:
|
try:
|
||||||
|
@@ -50,6 +50,8 @@ def compute_mdhash_id_for_qdrant(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QdrantVectorDBStorage(BaseVectorStorage):
|
class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_collection_if_not_exist(
|
def create_collection_if_not_exist(
|
||||||
client: QdrantClient, collection_name: str, **kwargs
|
client: QdrantClient, collection_name: str, **kwargs
|
||||||
@@ -59,6 +61,12 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|||||||
client.create_collection(collection_name, **kwargs)
|
client.create_collection(collection_name, **kwargs)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
|
if cosine_threshold is None:
|
||||||
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
self._client = QdrantClient(
|
self._client = QdrantClient(
|
||||||
url=os.environ.get(
|
url=os.environ.get(
|
||||||
"QDRANT_URL", config.get("qdrant", "uri", fallback=None)
|
"QDRANT_URL", config.get("qdrant", "uri", fallback=None)
|
||||||
@@ -131,4 +139,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|||||||
with_payload=True,
|
with_payload=True,
|
||||||
)
|
)
|
||||||
logger.debug(f"query result: {results}")
|
logger.debug(f"query result: {results}")
|
||||||
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
|
# 添加余弦相似度过滤
|
||||||
|
filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold]
|
||||||
|
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results]
|
||||||
|
@@ -212,18 +212,18 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
# db instance must be injected before use
|
# db instance must be injected before use
|
||||||
# db: TiDB
|
# db: TiDB
|
||||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
cosine_better_than_threshold: float = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||||
)
|
)
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
# Use global config value if specified, otherwise use default
|
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
self.cosine_better_than_threshold = config.get(
|
cosine_threshold = config.get("cosine_better_than_threshold")
|
||||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
if cosine_threshold is None:
|
||||||
)
|
raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
|
||||||
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||||
"""Search from tidb vector"""
|
"""Search from tidb vector"""
|
||||||
|
@@ -420,6 +420,15 @@ class LightRAG:
|
|||||||
# Check environment variables
|
# Check environment variables
|
||||||
self.check_storage_env_vars(storage_name)
|
self.check_storage_env_vars(storage_name)
|
||||||
|
|
||||||
|
# Ensure vector_db_storage_cls_kwargs has required fields
|
||||||
|
default_vector_db_kwargs = {
|
||||||
|
"cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||||
|
}
|
||||||
|
self.vector_db_storage_cls_kwargs = {
|
||||||
|
**default_vector_db_kwargs,
|
||||||
|
**self.vector_db_storage_cls_kwargs
|
||||||
|
}
|
||||||
|
|
||||||
# show config
|
# show config
|
||||||
global_config = asdict(self)
|
global_config = asdict(self)
|
||||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
|
||||||
|
Reference in New Issue
Block a user