Add optional ids filter to vector database query methods
- Updated query method signatures across multiple vector database implementations
- Added optional `ids` parameter to filter search results
- Consistent implementation across ChromaDB, Faiss, Milvus, MongoDB, NanoVectorDB, Oracle, Qdrant, and TiDB vector storage classes
This commit is contained in:
@@ -156,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
|||||||
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
embedding = await self.embedding_func([query])
|
embedding = await self.embedding_func([query])
|
||||||
|
|
||||||
|
@@ -171,7 +171,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
|||||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||||
return [m["__id__"] for m in list_data]
|
return [m["__id__"] for m in list_data]
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
||||||
"""
|
"""
|
||||||
|
@@ -101,7 +101,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
|||||||
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
embedding = await self.embedding_func([query])
|
embedding = await self.embedding_func([query])
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
|
@@ -938,7 +938,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
return list_data
|
return list_data
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""Queries the vector database using Atlas Vector Search."""
|
"""Queries the vector database using Atlas Vector Search."""
|
||||||
# Generate the embedding
|
# Generate the embedding
|
||||||
embedding = await self.embedding_func([query])
|
embedding = await self.embedding_func([query])
|
||||||
|
@@ -120,7 +120,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|||||||
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
# Execute embedding outside of lock to avoid long lock times
|
# Execute embedding outside of lock to avoid long lock times
|
||||||
embedding = await self.embedding_func([query])
|
embedding = await self.embedding_func([query])
|
||||||
embedding = embedding[0]
|
embedding = embedding[0]
|
||||||
|
@@ -417,7 +417,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
#################### query method ###############
|
#################### query method ###############
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
embeddings = await self.embedding_func([query])
|
embeddings = await self.embedding_func([query])
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
# 转换精度
|
# 转换精度
|
||||||
|
@@ -123,7 +123,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
|||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
embedding = await self.embedding_func([query])
|
embedding = await self.embedding_func([query])
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self.namespace,
|
collection_name=self.namespace,
|
||||||
|
@@ -306,7 +306,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
|||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
async def query(
|
||||||
|
self, query: str, top_k: int, ids: list[str] | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""Search from tidb vector"""
|
"""Search from tidb vector"""
|
||||||
embeddings = await self.embedding_func([query])
|
embeddings = await self.embedding_func([query])
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
|
Reference in New Issue
Block a user