Fix embedding type conversion issue in the query function of chroma_impl; chroma_impl supports local persistent client: PersistentClient
This commit is contained in:
@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
# ChromaDB Configuration
|
||||
CHROMADB_USE_LOCAL_PERSISTENT = False
|
||||
# Local PersistentClient Configuration
|
||||
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
|
||||
# Remote HttpClient Configuration
|
||||
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
|
||||
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
|
||||
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
|
||||
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
|
||||
|
||||
async def initialize_rag():
|
||||
embedding_func_instance = await create_embedding_function_instance()
|
||||
|
||||
return LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete,
|
||||
embedding_func=embedding_func_instance,
|
||||
vector_storage="ChromaVectorDBStorage",
|
||||
log_level="DEBUG",
|
||||
embedding_batch_num=32,
|
||||
vector_db_storage_cls_kwargs={
|
||||
"host": CHROMADB_HOST,
|
||||
"port": CHROMADB_PORT,
|
||||
"auth_token": CHROMADB_AUTH_TOKEN,
|
||||
"auth_provider": CHROMADB_AUTH_PROVIDER,
|
||||
"auth_header_name": CHROMADB_AUTH_HEADER,
|
||||
"collection_settings": {
|
||||
"hnsw:space": "cosine",
|
||||
"hnsw:construction_ef": 128,
|
||||
"hnsw:search_ef": 128,
|
||||
"hnsw:M": 16,
|
||||
"hnsw:batch_size": 100,
|
||||
"hnsw:sync_threshold": 1000,
|
||||
if CHROMADB_USE_LOCAL_PERSISTENT:
|
||||
return LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete,
|
||||
embedding_func=embedding_func_instance,
|
||||
vector_storage="ChromaVectorDBStorage",
|
||||
log_level="DEBUG",
|
||||
embedding_batch_num=32,
|
||||
vector_db_storage_cls_kwargs={
|
||||
"local_path": CHROMADB_LOCAL_PATH,
|
||||
"collection_settings": {
|
||||
"hnsw:space": "cosine",
|
||||
"hnsw:construction_ef": 128,
|
||||
"hnsw:search_ef": 128,
|
||||
"hnsw:M": 16,
|
||||
"hnsw:batch_size": 100,
|
||||
"hnsw:sync_threshold": 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
return LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete,
|
||||
embedding_func=embedding_func_instance,
|
||||
vector_storage="ChromaVectorDBStorage",
|
||||
log_level="DEBUG",
|
||||
embedding_batch_num=32,
|
||||
vector_db_storage_cls_kwargs={
|
||||
"host": CHROMADB_HOST,
|
||||
"port": CHROMADB_PORT,
|
||||
"auth_token": CHROMADB_AUTH_TOKEN,
|
||||
"auth_provider": CHROMADB_AUTH_PROVIDER,
|
||||
"auth_header_name": CHROMADB_AUTH_HEADER,
|
||||
"collection_settings": {
|
||||
"hnsw:space": "cosine",
|
||||
"hnsw:construction_ef": 128,
|
||||
"hnsw:search_ef": 128,
|
||||
"hnsw:M": 16,
|
||||
"hnsw:batch_size": 100,
|
||||
"hnsw:sync_threshold": 1000,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Run the initialization
|
||||
|
@@ -3,7 +3,7 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
from chromadb import HttpClient
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from lightrag.utils import logger
|
||||
@@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
**user_collection_settings,
|
||||
}
|
||||
|
||||
auth_provider = config.get(
|
||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
auth_credentials = config.get("auth_token", "secret-token")
|
||||
headers = {}
|
||||
local_path = config.get("local_path", None)
|
||||
if local_path:
|
||||
self._client = PersistentClient(
|
||||
path=local_path,
|
||||
settings=Settings(
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
auth_provider = config.get(
|
||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
auth_credentials = config.get("auth_token", "secret-token")
|
||||
headers = {}
|
||||
|
||||
if "token_authn" in auth_provider:
|
||||
headers = {
|
||||
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
||||
}
|
||||
elif "basic_authn" in auth_provider:
|
||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||
if "token_authn" in auth_provider:
|
||||
headers = {
|
||||
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
||||
}
|
||||
elif "basic_authn" in auth_provider:
|
||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||
|
||||
self._client = HttpClient(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 8000),
|
||||
headers=headers,
|
||||
settings=Settings(
|
||||
chroma_api_impl="rest",
|
||||
chroma_client_auth_provider=auth_provider,
|
||||
chroma_client_auth_credentials=auth_credentials,
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
self._client = HttpClient(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 8000),
|
||||
headers=headers,
|
||||
settings=Settings(
|
||||
chroma_api_impl="rest",
|
||||
chroma_client_auth_provider=auth_provider,
|
||||
chroma_client_auth_credentials=auth_credentials,
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.namespace,
|
||||
@@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=embedding.tolist(),
|
||||
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
|
||||
n_results=top_k * 2, # Request more results to allow for filtering
|
||||
include=["metadatas", "distances", "documents"],
|
||||
)
|
||||
|
Reference in New Issue
Block a user