diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 6b521180..35b4cb58 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -156,7 +156,9 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB upsert: {str(e)}") raise - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: try: embedding = await self.embedding_func([query]) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index ab036e6f..6832b756 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -171,7 +171,9 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") return [m["__id__"] for m in list_data] - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: """ Search by a textual query; returns top_k results with their metadata + similarity distance. """ diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f3a6fcc4..8b82ddf1 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -101,7 +101,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): results = self._client.upsert(collection_name=self.namespace, data=list_data) return results - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index f2ab6ae0..a2d9e51f 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -938,7 +938,9 @@ class MongoVectorDBStorage(BaseVectorStorage): return list_data - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" # Generate the embedding embedding = await self.embedding_func([query]) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 07ccd566..c97aaa3a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -120,7 +120,9 @@ class NanoVectorDBStorage(BaseVectorStorage): f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" ) - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: # Execute embedding outside of lock to avoid long lock times embedding = await self.embedding_func([query]) embedding = embedding[0] diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index eda3ca63..04552e34 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -417,7 +417,9 @@ class OracleVectorDBStorage(BaseVectorStorage): self.db = None #################### query method ############### - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] # 转换精度 diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 53a59c2f..e32c4335 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -123,7 +123,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) return results - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 7af9b48a..9d807798 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -306,7 +306,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): await ClientManager.release_client(self.db) self.db = None - async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: """Search from tidb vector""" embeddings = await self.embedding_func([query]) embedding = embeddings[0]