Fix embedding type conversion issue in the query function of chroma_impl; chroma_impl supports local persistent client: PersistentClient

This commit is contained in:
destiny
2025-02-14 11:00:54 +08:00
parent be7a001ad8
commit a5cd2b1958
2 changed files with 82 additions and 48 deletions

View File

@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance()
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
if CHROMADB_USE_LOCAL_PERSISTENT:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"local_path": CHROMADB_LOCAL_PATH,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
},
)
)
else:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)
# Run the initialization

View File

@@ -3,7 +3,7 @@ import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
@@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
**user_collection_settings,
}
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
local_path = config.get("local_path", None)
if local_path:
self._client = PersistentClient(
path=local_path,
settings=Settings(
allow_reset=True,
anonymized_telemetry=False,
),
)
else:
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._collection = self._client.get_or_create_collection(
name=self.namespace,
@@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query])
results = self._collection.query(
query_embeddings=embedding.tolist(),
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)