Fix embedding type conversion issue in the query function of chroma_impl; chroma_impl supports local persistent client: PersistentClient
This commit is contained in:
@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
|
|||||||
os.mkdir(WORKING_DIR)
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
# ChromaDB Configuration
|
# ChromaDB Configuration
|
||||||
|
CHROMADB_USE_LOCAL_PERSISTENT = False
|
||||||
|
# Local PersistentClient Configuration
|
||||||
|
CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
|
||||||
|
# Remote HttpClient Configuration
|
||||||
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
|
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
|
||||||
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
|
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
|
||||||
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
|
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
|
||||||
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
|
|||||||
|
|
||||||
async def initialize_rag():
|
async def initialize_rag():
|
||||||
embedding_func_instance = await create_embedding_function_instance()
|
embedding_func_instance = await create_embedding_function_instance()
|
||||||
|
if CHROMADB_USE_LOCAL_PERSISTENT:
|
||||||
return LightRAG(
|
return LightRAG(
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
llm_model_func=gpt_4o_mini_complete,
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
embedding_func=embedding_func_instance,
|
embedding_func=embedding_func_instance,
|
||||||
vector_storage="ChromaVectorDBStorage",
|
vector_storage="ChromaVectorDBStorage",
|
||||||
log_level="DEBUG",
|
log_level="DEBUG",
|
||||||
embedding_batch_num=32,
|
embedding_batch_num=32,
|
||||||
vector_db_storage_cls_kwargs={
|
vector_db_storage_cls_kwargs={
|
||||||
"host": CHROMADB_HOST,
|
"local_path": CHROMADB_LOCAL_PATH,
|
||||||
"port": CHROMADB_PORT,
|
"collection_settings": {
|
||||||
"auth_token": CHROMADB_AUTH_TOKEN,
|
"hnsw:space": "cosine",
|
||||||
"auth_provider": CHROMADB_AUTH_PROVIDER,
|
"hnsw:construction_ef": 128,
|
||||||
"auth_header_name": CHROMADB_AUTH_HEADER,
|
"hnsw:search_ef": 128,
|
||||||
"collection_settings": {
|
"hnsw:M": 16,
|
||||||
"hnsw:space": "cosine",
|
"hnsw:batch_size": 100,
|
||||||
"hnsw:construction_ef": 128,
|
"hnsw:sync_threshold": 1000,
|
||||||
"hnsw:search_ef": 128,
|
},
|
||||||
"hnsw:M": 16,
|
|
||||||
"hnsw:batch_size": 100,
|
|
||||||
"hnsw:sync_threshold": 1000,
|
|
||||||
},
|
},
|
||||||
},
|
)
|
||||||
)
|
else:
|
||||||
|
return LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
|
embedding_func=embedding_func_instance,
|
||||||
|
vector_storage="ChromaVectorDBStorage",
|
||||||
|
log_level="DEBUG",
|
||||||
|
embedding_batch_num=32,
|
||||||
|
vector_db_storage_cls_kwargs={
|
||||||
|
"host": CHROMADB_HOST,
|
||||||
|
"port": CHROMADB_PORT,
|
||||||
|
"auth_token": CHROMADB_AUTH_TOKEN,
|
||||||
|
"auth_provider": CHROMADB_AUTH_PROVIDER,
|
||||||
|
"auth_header_name": CHROMADB_AUTH_HEADER,
|
||||||
|
"collection_settings": {
|
||||||
|
"hnsw:space": "cosine",
|
||||||
|
"hnsw:construction_ef": 128,
|
||||||
|
"hnsw:search_ef": 128,
|
||||||
|
"hnsw:M": 16,
|
||||||
|
"hnsw:batch_size": 100,
|
||||||
|
"hnsw:sync_threshold": 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Run the initialization
|
# Run the initialization
|
||||||
|
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
from typing import Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from chromadb import HttpClient
|
from chromadb import HttpClient, PersistentClient
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from lightrag.base import BaseVectorStorage
|
from lightrag.base import BaseVectorStorage
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
@@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
|||||||
**user_collection_settings,
|
**user_collection_settings,
|
||||||
}
|
}
|
||||||
|
|
||||||
auth_provider = config.get(
|
local_path = config.get("local_path", None)
|
||||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
if local_path:
|
||||||
)
|
self._client = PersistentClient(
|
||||||
auth_credentials = config.get("auth_token", "secret-token")
|
path=local_path,
|
||||||
headers = {}
|
settings=Settings(
|
||||||
|
allow_reset=True,
|
||||||
|
anonymized_telemetry=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
auth_provider = config.get(
|
||||||
|
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||||
|
)
|
||||||
|
auth_credentials = config.get("auth_token", "secret-token")
|
||||||
|
headers = {}
|
||||||
|
|
||||||
if "token_authn" in auth_provider:
|
if "token_authn" in auth_provider:
|
||||||
headers = {
|
headers = {
|
||||||
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
||||||
}
|
}
|
||||||
elif "basic_authn" in auth_provider:
|
elif "basic_authn" in auth_provider:
|
||||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||||
|
|
||||||
self._client = HttpClient(
|
self._client = HttpClient(
|
||||||
host=config.get("host", "localhost"),
|
host=config.get("host", "localhost"),
|
||||||
port=config.get("port", 8000),
|
port=config.get("port", 8000),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
settings=Settings(
|
settings=Settings(
|
||||||
chroma_api_impl="rest",
|
chroma_api_impl="rest",
|
||||||
chroma_client_auth_provider=auth_provider,
|
chroma_client_auth_provider=auth_provider,
|
||||||
chroma_client_auth_credentials=auth_credentials,
|
chroma_client_auth_credentials=auth_credentials,
|
||||||
allow_reset=True,
|
allow_reset=True,
|
||||||
anonymized_telemetry=False,
|
anonymized_telemetry=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._collection = self._client.get_or_create_collection(
|
self._collection = self._client.get_or_create_collection(
|
||||||
name=self.namespace,
|
name=self.namespace,
|
||||||
@@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
|||||||
embedding = await self.embedding_func([query])
|
embedding = await self.embedding_func([query])
|
||||||
|
|
||||||
results = self._collection.query(
|
results = self._collection.query(
|
||||||
query_embeddings=embedding.tolist(),
|
query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
|
||||||
n_results=top_k * 2, # Request more results to allow for filtering
|
n_results=top_k * 2, # Request more results to allow for filtering
|
||||||
include=["metadatas", "distances", "documents"],
|
include=["metadatas", "distances", "documents"],
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user