Fix the embedding type conversion issue in the query function of chroma_impl; add support in chroma_impl for a local persistent client (PersistentClient)

This commit is contained in:
destiny
2025-02-14 11:00:54 +08:00
parent be7a001ad8
commit a5cd2b1958
2 changed files with 82 additions and 48 deletions

View File

@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
# ChromaDB Configuration # ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost") CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000)) CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token") CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
async def initialize_rag(): async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance() embedding_func_instance = await create_embedding_function_instance()
if CHROMADB_USE_LOCAL_PERSISTENT:
return LightRAG( return LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance, embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage", vector_storage="ChromaVectorDBStorage",
log_level="DEBUG", log_level="DEBUG",
embedding_batch_num=32, embedding_batch_num=32,
vector_db_storage_cls_kwargs={ vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST, "local_path": CHROMADB_LOCAL_PATH,
"port": CHROMADB_PORT, "collection_settings": {
"auth_token": CHROMADB_AUTH_TOKEN, "hnsw:space": "cosine",
"auth_provider": CHROMADB_AUTH_PROVIDER, "hnsw:construction_ef": 128,
"auth_header_name": CHROMADB_AUTH_HEADER, "hnsw:search_ef": 128,
"collection_settings": { "hnsw:M": 16,
"hnsw:space": "cosine", "hnsw:batch_size": 100,
"hnsw:construction_ef": 128, "hnsw:sync_threshold": 1000,
"hnsw:search_ef": 128, },
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
}, },
}, )
) else:
return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)
# Run the initialization # Run the initialization

View File

@@ -3,7 +3,7 @@ import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Union
import numpy as np import numpy as np
from chromadb import HttpClient from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings from chromadb.config import Settings
from lightrag.base import BaseVectorStorage from lightrag.base import BaseVectorStorage
from lightrag.utils import logger from lightrag.utils import logger
@@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
**user_collection_settings, **user_collection_settings,
} }
auth_provider = config.get( local_path = config.get("local_path", None)
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" if local_path:
) self._client = PersistentClient(
auth_credentials = config.get("auth_token", "secret-token") path=local_path,
headers = {} settings=Settings(
allow_reset=True,
anonymized_telemetry=False,
),
)
else:
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
if "token_authn" in auth_provider: if "token_authn" in auth_provider:
headers = { headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
} }
elif "basic_authn" in auth_provider: elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin") auth_credentials = config.get("auth_credentials", "admin:admin")
self._client = HttpClient( self._client = HttpClient(
host=config.get("host", "localhost"), host=config.get("host", "localhost"),
port=config.get("port", 8000), port=config.get("port", 8000),
headers=headers, headers=headers,
settings=Settings( settings=Settings(
chroma_api_impl="rest", chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider, chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials, chroma_client_auth_credentials=auth_credentials,
allow_reset=True, allow_reset=True,
anonymized_telemetry=False, anonymized_telemetry=False,
), ),
) )
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
name=self.namespace, name=self.namespace,
@@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
results = self._collection.query( results = self._collection.query(
query_embeddings=embedding.tolist(), query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
n_results=top_k * 2, # Request more results to allow for filtering n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"], include=["metadatas", "distances", "documents"],
) )