Add optional ids filter to vector database query methods

- Updated query method signatures across multiple vector database implementations
- Added optional `ids` parameter to filter search results
- Consistent implementation across ChromaDB, Faiss, Milvus, MongoDB, NanoVectorDB, Oracle, Qdrant, and TiDB vector storage classes
This commit is contained in:
Roy
2025-03-11 15:22:17 +00:00
parent 92ae895713
commit 8aa9d0e6ca
8 changed files with 24 additions and 8 deletions

View File

@@ -156,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB upsert: {str(e)}") logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise raise
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
try: try:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])

View File

@@ -171,7 +171,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
""" """
Search by a textual query; returns top_k results with their metadata + similarity distance. Search by a textual query; returns top_k results with their metadata + similarity distance.
""" """

View File

@@ -101,7 +101,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
results = self._client.upsert(collection_name=self.namespace, data=list_data) results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results return results
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.namespace,

View File

@@ -938,7 +938,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data return list_data
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search.""" """Queries the vector database using Atlas Vector Search."""
# Generate the embedding # Generate the embedding
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])

View File

@@ -120,7 +120,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
) )
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid long lock times # Execute embedding outside of lock to avoid long lock times
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
embedding = embedding[0] embedding = embedding[0]

View File

@@ -417,7 +417,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
self.db = None self.db = None
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
# 转换精度 # 转换精度

View File

@@ -123,7 +123,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.namespace,

View File

@@ -306,7 +306,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Search from tidb vector""" """Search from tidb vector"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]