Fix embedding type conversion issue in the query function of chroma_impl; chroma_impl supports local persistent client: PersistentClient
This commit is contained in:
@@ -3,7 +3,7 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
from chromadb import HttpClient
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from lightrag.utils import logger
|
||||
@@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
**user_collection_settings,
|
||||
}
|
||||
|
||||
auth_provider = config.get(
|
||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
auth_credentials = config.get("auth_token", "secret-token")
|
||||
headers = {}
|
||||
local_path = config.get("local_path", None)
|
||||
if local_path:
|
||||
self._client = PersistentClient(
|
||||
path=local_path,
|
||||
settings=Settings(
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
auth_provider = config.get(
|
||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
auth_credentials = config.get("auth_token", "secret-token")
|
||||
headers = {}
|
||||
|
||||
if "token_authn" in auth_provider:
|
||||
headers = {
|
||||
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
||||
}
|
||||
elif "basic_authn" in auth_provider:
|
||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||
if "token_authn" in auth_provider:
|
||||
headers = {
|
||||
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
||||
}
|
||||
elif "basic_authn" in auth_provider:
|
||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||
|
||||
self._client = HttpClient(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 8000),
|
||||
headers=headers,
|
||||
settings=Settings(
|
||||
chroma_api_impl="rest",
|
||||
chroma_client_auth_provider=auth_provider,
|
||||
chroma_client_auth_credentials=auth_credentials,
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
self._client = HttpClient(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 8000),
|
||||
headers=headers,
|
||||
settings=Settings(
|
||||
chroma_api_impl="rest",
|
||||
chroma_client_auth_provider=auth_provider,
|
||||
chroma_client_auth_credentials=auth_credentials,
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.namespace,
|
||||
@@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=embedding.tolist(),
|
||||
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
|
||||
n_results=top_k * 2, # Request more results to allow for filtering
|
||||
include=["metadatas", "distances", "documents"],
|
||||
)
|
||||
|
Reference in New Issue
Block a user