diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index cb3b59f1..e32346f9 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,6 +1,6 @@ import asyncio from dataclasses import dataclass -from typing import Union +from typing import Any import numpy as np from chromadb import HttpClient, PersistentClient from chromadb.config import Settings @@ -102,7 +102,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"ChromaDB initialization failed: {str(e)}") raise - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not data: logger.warning("Empty data provided to vector DB") return [] @@ -151,7 +151,7 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB upsert: {str(e)}") raise - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: try: embedding = await self.embedding_func([query]) @@ -183,6 +183,15 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB query: {str(e)}") raise - async def index_done_callback(self): + + async def index_done_callback(self) -> None: # ChromaDB handles persistence automatically pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 9a5f7e4e..3027f3f0 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -1,6 +1,7 @@ import os import time import asyncio +from typing import Any import faiss import json import numpy as np @@ -57,7 +58,7 @@ class FaissVectorDBStorage(BaseVectorStorage): # Attempt to load an existing index + metadata from disk self._load_faiss_index() - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ Insert or update vectors in the Faiss index. @@ -147,7 +148,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") return [m["__id__"] for m in list_data] - async def query(self, query: str, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """ Search by a textual query; returns top_k results with their metadata + similarity distance. """ @@ -210,7 +211,7 @@ class FaissVectorDBStorage(BaseVectorStorage): f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" ) - async def delete_entity(self, entity_name: str): + async def delete_entity(self, entity_name: str) -> None: """ Delete a single entity by computing its hashed ID the same way your code does it with `compute_mdhash_id`. @@ -234,7 +235,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._remove_faiss_ids(relations) logger.debug(f"Deleted {len(relations)} relations for {entity_name}") - async def index_done_callback(self): + async def index_done_callback(self) -> None: """ Called after indexing is done (save Faiss index + metadata). """ diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f4d9d47f..d67f03b1 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1,5 +1,6 @@ import asyncio import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -71,7 +72,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): dimension=self.embedding_func.embedding_dim, ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -106,7 +107,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): results = self._client.upsert(collection_name=self.namespace, data=list_data) return results - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, @@ -123,3 +124,14 @@ class MilvusVectorDBStorage(BaseVectorStorage): {**dp["entity"], "id": dp["id"], "distance": dp["distance"]} for dp in results[0] ] + + async def index_done_callback(self) -> None: + pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index c216e7be..39bb9f18 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -844,7 +844,7 @@ class MongoVectorDBStorage(BaseVectorStorage): except PyMongoError as _: logger.debug("vector index already exist") - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") if not data: logger.warning("You are inserting an empty data set to vector DB") @@ -887,7 +887,7 @@ class MongoVectorDBStorage(BaseVectorStorage): return list_data - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Queries the vector database using Atlas Vector Search.""" # Generate the embedding embedding = await self.embedding_func([query]) @@ -921,6 +921,16 @@ class MongoVectorDBStorage(BaseVectorStorage): for doc in results ] + async def index_done_callback(self) -> None: + pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): """Check if the collection exists. if not, create it.""" diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 5d786646..8b931424 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -50,6 +50,7 @@ Usage: import asyncio import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -95,7 +96,7 @@ class NanoVectorDBStorage(BaseVectorStorage): self.embedding_func.embedding_dim, storage_file=self._client_file_name ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -139,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage): f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" ) - async def query(self, query: str, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) embedding = embedding[0] results = self._client.query( @@ -176,7 +177,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - async def delete_entity(self, entity_name: str): + async def delete_entity(self, entity_name: str) -> None: try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( @@ -211,7 +212,7 @@ class NanoVectorDBStorage(BaseVectorStorage): except Exception as e: logger.error(f"Error deleting relations for {entity_name}: {e}") - async def index_done_callback(self): + async def index_done_callback(self) -> None: # Protect file write operation async with self._save_lock: self._client.save() diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 65f1060c..197d101e 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -307,7 +307,7 @@ class OracleKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) - async def index_done_callback(self): + async def index_done_callback(self) -> None: if is_namespace( self.namespace, (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), @@ -330,16 +330,14 @@ class OracleVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold - async def upsert(self, data: dict[str, dict]): - """向向量数据库中插入数据""" + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pass async def index_done_callback(self): pass #################### query method ############### - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] # 转换精度 @@ -359,6 +357,13 @@ class OracleVectorDBStorage(BaseVectorStorage): # print("vector search result:",results) return results + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError @dataclass class OracleGraphStorage(BaseGraphStorage): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a44aefe7..5dbc6a8e 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -287,7 +287,7 @@ class PGKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) - async def index_done_callback(self): + async def index_done_callback(self) -> None: if is_namespace( self.namespace, (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), @@ -352,7 +352,7 @@ class PGVectorStorage(BaseVectorStorage): } return upsert_sql, data - async def upsert(self, data: Dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") @@ -398,12 +398,11 @@ class PGVectorStorage(BaseVectorStorage): await self.db.execute(upsert_sql, data) - async def index_done_callback(self): + async def index_done_callback(self) -> None: logger.info("vector data had been saved into postgresql db!") #################### query method ############### - async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: - """从向量数据库中查询数据""" + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embeddings = await self.embedding_func([query]) embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) @@ -417,6 +416,13 @@ class PGVectorStorage(BaseVectorStorage): results = await self.db.query(sql, params=params, multirows=True) return results + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError @dataclass class PGDocStatusStorage(DocStatusStorage): diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 3af76328..18a50082 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -1,5 +1,6 @@ import asyncio import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import numpy as np @@ -85,7 +86,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ), ) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if not len(data): logger.warning("You insert an empty data to vector DB") return [] @@ -130,7 +131,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) return results - async def query(self, query, top_k=5): + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: embedding = await self.embedding_func([query]) results = self._client.search( collection_name=self.namespace, @@ -143,3 +144,14 @@ class QdrantVectorDBStorage(BaseVectorStorage): logger.debug(f"query result: {results}") return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] + + async def index_done_callback(self) -> None: + pass + + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 00b8003d..a5a5c80d 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -227,7 +227,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold - async def query(self, query: str, top_k: int) -> list[dict]: + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Search from tidb vector""" embeddings = await self.embedding_func([query]) embedding = embeddings[0] @@ -249,7 +249,7 @@ class TiDBVectorDBStorage(BaseVectorStorage): return results ###### INSERT entities And relationships ###### - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: # ignore, upsert in TiDBKVStorage already if not len(data): logger.warning("You insert an empty data to vector DB") @@ -333,6 +333,14 @@ class TiDBVectorDBStorage(BaseVectorStorage): return await self.db.query(SQL, params, multirows=True) + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name""" + raise NotImplementedError + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity by scanning metadata""" + raise NotImplementedError + @dataclass class TiDBGraphStorage(BaseGraphStorage): # db instance must be injected before use