diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py new file mode 100644 index 00000000..200e780c --- /dev/null +++ b/lightrag/kg/chroma_impl.py @@ -0,0 +1,172 @@ +import asyncio +from dataclasses import dataclass +from typing import Union +import numpy as np +from chromadb import HttpClient +from chromadb.config import Settings +from lightrag.base import BaseVectorStorage +from lightrag.utils import logger + + +@dataclass +class ChromaVectorDBStorage(BaseVectorStorage): + """ChromaDB vector storage implementation.""" + + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + try: + # Use global config value if specified, otherwise use default + self.cosine_better_than_threshold = self.global_config.get( + "cosine_better_than_threshold", self.cosine_better_than_threshold + ) + + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + user_collection_settings = config.get("collection_settings", {}) + # Default HNSW index settings for ChromaDB + default_collection_settings = { + # Distance metric used for similarity search (cosine similarity) + "hnsw:space": "cosine", + # Number of nearest neighbors to explore during index construction + # Higher values = better recall but slower indexing + "hnsw:construction_ef": 128, + # Number of nearest neighbors to explore during search + # Higher values = better recall but slower search + "hnsw:search_ef": 128, + # Number of connections per node in the HNSW graph + # Higher values = better recall but more memory usage + "hnsw:M": 16, + # Number of vectors to process in one batch during indexing + "hnsw:batch_size": 100, + # Number of updates before forcing index synchronization + # Lower values = more frequent syncs but slower indexing + "hnsw:sync_threshold": 1000, + } + collection_settings = { + **default_collection_settings, + **user_collection_settings, + } + + auth_provider = config.get( + "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" + ) + auth_credentials = config.get("auth_token", "secret-token") + headers = {} + + if "token_authn" in auth_provider: + headers = { + config.get("auth_header_name", "X-Chroma-Token"): auth_credentials + } + elif "basic_authn" in auth_provider: + auth_credentials = config.get("auth_credentials", "admin:admin") + + self._client = HttpClient( + host=config.get("host", "localhost"), + port=config.get("port", 8000), + headers=headers, + settings=Settings( + chroma_api_impl="rest", + chroma_client_auth_provider=auth_provider, + chroma_client_auth_credentials=auth_credentials, + allow_reset=True, + anonymized_telemetry=False, + ), + ) + + self._collection = self._client.get_or_create_collection( + name=self.namespace, + metadata={ + **collection_settings, + "dimension": self.embedding_func.embedding_dim, + }, + ) + # Use batch size from collection settings if specified + self._max_batch_size = self.global_config.get( + "embedding_batch_num", collection_settings.get("hnsw:batch_size", 32) + ) + except Exception as e: + logger.error(f"ChromaDB initialization failed: {str(e)}") + raise + + async def upsert(self, data: dict[str, dict]): + if not data: + logger.warning("Empty data provided to vector DB") + return [] + + try: + ids = list(data.keys()) + documents = [v["content"] for v in data.values()] + metadatas = [ + {k: v for k, v in item.items() if k in self.meta_fields} + or {"_default": "true"} + for item in data.values() + ] + + # Process in batches + batches = [ + documents[i : i + self._max_batch_size] + for i in range(0, len(documents), self._max_batch_size) + ] + + embedding_tasks = [self.embedding_func(batch) for batch in batches] + embeddings_list = [] + + # Pre-allocate embeddings_list with known size + embeddings_list = [None] * len(embedding_tasks) + + # Use asyncio.gather instead of as_completed if order doesn't matter + embeddings_results = await asyncio.gather(*embedding_tasks) + embeddings_list = list(embeddings_results) + + embeddings = np.concatenate(embeddings_list) + + # Upsert in batches + for i in range(0, len(ids), self._max_batch_size): + batch_slice = slice(i, i + self._max_batch_size) + + self._collection.upsert( + ids=ids[batch_slice], + embeddings=embeddings[batch_slice].tolist(), + documents=documents[batch_slice], + metadatas=metadatas[batch_slice], + ) + + return ids + + except Exception as e: + logger.error(f"Error during ChromaDB upsert: {str(e)}") + raise + + async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: + try: + embedding = await self.embedding_func([query]) + + results = self._collection.query( + query_embeddings=embedding.tolist(), + n_results=top_k * 2, # Request more results to allow for filtering + include=["metadatas", "distances", "documents"], + ) + + # Filter results by cosine similarity threshold and take top k + # We request 2x results initially to have enough after filtering + # ChromaDB returns cosine similarity (1 = identical, 0 = orthogonal) + # We convert to distance (0 = identical, 1 = orthogonal) via (1 - similarity) + # Only keep results with distance below threshold, then take top k + return [ + { + "id": results["ids"][0][i], + "distance": 1 - results["distances"][0][i], + "content": results["documents"][0][i], + **results["metadatas"][0][i], + } + for i in range(len(results["ids"][0])) + if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold + ][:top_k] + + except Exception as e: + logger.error(f"Error during ChromaDB query: {str(e)}") + raise + + async def index_done_callback(self): + # ChromaDB handles persistence automatically + pass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 833926e5..ff14787f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -48,18 +48,24 @@ from .storage import ( def lazy_external_import(module_name: str, class_name: str): - """Lazily import an external module and return a class from it.""" + """Lazily import a class from an external module based on the package of the caller.""" - def import_class(): + def import_class(*args, **kwargs): + import inspect import importlib - # Import the module using importlib - module = importlib.import_module(module_name) + # Get the caller's module and package + caller_frame = inspect.currentframe().f_back + module = inspect.getmodule(caller_frame) + package = module.__package__ if module else None - # Get the class from the module - return getattr(module, class_name) + # Import the module using importlib with package context + module = importlib.import_module(module_name, package=package) + + # Get the class from the module and instantiate it + cls = getattr(module, class_name) + return cls(*args, **kwargs) - # Return the import_class function itself, not its result return import_class @@ -69,6 +75,7 @@ OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") +ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") def always_get_an_event_loop() -> asyncio.AbstractEventLoop: @@ -256,6 +263,7 @@ class LightRAG: "NanoVectorDBStorage": NanoVectorDBStorage, "OracleVectorDBStorage": OracleVectorDBStorage, "MilvusVectorDBStorge": MilvusVectorDBStorge, + "ChromaVectorDBStorage": ChromaVectorDBStorage, # graph storage "NetworkXStorage": NetworkXStorage, "Neo4JStorage": Neo4JStorage, diff --git a/test_chromadb.py b/test_chromadb.py new file mode 100644 index 00000000..df721bb2 --- /dev/null +++ b/test_chromadb.py @@ -0,0 +1,113 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm import gpt_4o_mini_complete, openai_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np + +######### +# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() +# import nest_asyncio +# nest_asyncio.apply() +######### +WORKING_DIR = "./chromadb_test_dir" +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +# ChromaDB Configuration +CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost") +CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000)) +CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token") +CHROMADB_AUTH_PROVIDER = os.environ.get( + "CHROMADB_AUTH_PROVIDER", "chromadb.auth.token_authn.TokenAuthClientProvider" +) +CHROMADB_AUTH_HEADER = os.environ.get("CHROMADB_AUTH_HEADER", "X-Chroma-Token") + +# Embedding Configuration and Functions +EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) + +# ChromaDB requires knowing the dimension of embeddings upfront when +# creating a collection. The embedding dimension is model-specific +# (e.g. text-embedding-3-large uses 3072 dimensions) +# we dynamically determine it by running a test embedding +# and then pass it to the ChromaDBStorage class + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model=EMBEDDING_MODEL, + ) + + +async def get_embedding_dimension(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + return embedding.shape[1] + + +async def create_embedding_function_instance(): + # Get embedding dimension + embedding_dimension = await get_embedding_dimension() + # Create embedding function instance + return EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=EMBEDDING_MAX_TOKEN_SIZE, + func=embedding_func, + ) + + +async def initialize_rag(): + embedding_func_instance = await create_embedding_function_instance() + + return LightRAG( + working_dir=WORKING_DIR, + llm_model_func=gpt_4o_mini_complete, + embedding_func=embedding_func_instance, + vector_storage="ChromaVectorDBStorage", + log_level="DEBUG", + embedding_batch_num=32, + vector_db_storage_cls_kwargs={ + "host": CHROMADB_HOST, + "port": CHROMADB_PORT, + "auth_token": CHROMADB_AUTH_TOKEN, + "auth_provider": CHROMADB_AUTH_PROVIDER, + "auth_header_name": CHROMADB_AUTH_HEADER, + "collection_settings": { + "hnsw:space": "cosine", + "hnsw:construction_ef": 128, + "hnsw:search_ef": 128, + "hnsw:M": 16, + "hnsw:batch_size": 100, + "hnsw:sync_threshold": 1000, + }, + }, + ) + + +# Run the initialization +rag = asyncio.run(initialize_rag()) + +# with open("./dickens/book.txt", "r", encoding="utf-8") as f: +# rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +)