diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 200e780c..72a2627a 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,3 +1,4 @@ +import os import asyncio from dataclasses import dataclass from typing import Union @@ -12,16 +13,16 @@ from lightrag.utils import logger class ChromaVectorDBStorage(BaseVectorStorage): """ChromaDB vector storage implementation.""" - cosine_better_than_threshold: float = 0.2 + cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): try: # Use global config value if specified, otherwise use default - self.cosine_better_than_threshold = self.global_config.get( + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + self.cosine_better_than_threshold = config.get( "cosine_better_than_threshold", self.cosine_better_than_threshold ) - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) user_collection_settings = config.get("collection_settings", {}) # Default HNSW index settings for ChromaDB default_collection_settings = { diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index e42036b5..ed272fee 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -76,6 +76,12 @@ class NanoVectorDBStorage(BaseVectorStorage): cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): + # Use global config value if specified, otherwise use default + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + self.cosine_better_than_threshold = config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) @@ -83,14 +89,6 @@ class NanoVectorDBStorage(BaseVectorStorage): self._client = NanoVectorDB( self.embedding_func.embedding_dim, storage_file=self._client_file_name ) - # get cosine_better_than_threshold from LightRAG - vector_db_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = vector_db_kwargs.get( - "cosine_better_than_threshold", - self.global_config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ), - ) async def upsert(self, data: dict[str, dict]): logger.info(f"Inserting {len(data)} vectors to {self.namespace}")