Merge pull request #772 from hhxdestiny/chroma_support_local

chroma_impl supports local persistent client: PersistentClient
This commit is contained in:
zrguo
2025-02-16 19:21:53 +08:00
committed by GitHub
2 changed files with 82 additions and 48 deletions

View File

@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance()
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
if CHROMADB_USE_LOCAL_PERSISTENT:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"local_path": CHROMADB_LOCAL_PATH,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
},
)
)
else:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)
# Run the initialization

View File

@@ -2,7 +2,7 @@ import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
@@ -49,31 +49,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
**user_collection_settings,
}
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
local_path = config.get("local_path", None)
if local_path:
self._client = PersistentClient(
path=local_path,
settings=Settings(
allow_reset=True,
anonymized_telemetry=False,
),
)
else:
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._collection = self._client.get_or_create_collection(
name=self.namespace,
@@ -144,7 +154,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
results = self._collection.query(
query_embeddings=embedding.tolist(),
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)