diff --git a/lightrag/base.py b/lightrag/base.py index 4b963b43..60b9b3f1 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -84,9 +84,6 @@ class BaseVectorStorage(StorageNameSpace): class BaseKVStorage(StorageNameSpace): embedding_func: EmbeddingFunc - async def all_keys(self) -> list[str]: - raise NotImplementedError - async def get_by_id(self, id: str) -> dict[str, Any]: raise NotImplementedError @@ -103,9 +100,6 @@ class BaseKVStorage(StorageNameSpace): async def drop(self) -> None: raise NotImplementedError - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - raise NotImplementedError - @dataclass class BaseGraphStorage(StorageNameSpace): diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index e9225375..14565c86 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Any, Union +from typing import Any from lightrag.utils import ( logger, @@ -21,10 +21,7 @@ class JsonKVStorage(BaseKVStorage): self._data: dict[str, Any] = load_json(self._file_name) or {} self._lock = asyncio.Lock() logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - - async def all_keys(self) -> list[str]: - return list(self._data.keys()) - + async def index_done_callback(self): write_json(self._data, self._file_name) @@ -49,8 +46,4 @@ class JsonKVStorage(BaseKVStorage): self._data.update(left_data) async def drop(self) -> None: - self._data = {} - - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - result = [v for _, v in self._data.items() if v["status"] == status] - return result if result else None + self._data = {} \ No newline at end of file diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index b7b438bd..45d4bb07 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -29,9 +29,6 @@ class MongoKVStorage(BaseKVStorage): self._data = 
database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") - async def all_keys(self) -> list[str]: - return [x["_id"] for x in self._data.find({}, {"_id": 1})] - async def get_by_id(self, id: str) -> dict[str, Any]: return self._data.find_one({"_id": id}) @@ -77,11 +74,6 @@ class MongoKVStorage(BaseKVStorage): """Drop the collection""" await self._data.drop() - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - """Get documents by status and ids""" - return self._data.find({"status": status}) - - @dataclass class MongoGraphStorage(BaseGraphStorage): """ diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index c82db9a6..b648c9bc 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -229,12 +229,6 @@ class OracleKVStorage(BaseKVStorage): res = [{k: v} for k, v in dict_res.items()] return res - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - """Specifically for llm_response_cache.""" - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] - params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) - async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 01e3688a..b37f8434 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -237,16 +237,6 @@ class PGKVStorage(BaseKVStorage): params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) - async def all_keys(self) -> list[dict]: - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - sql = "select workspace,mode,id from lightrag_llm_cache" - res = await self.db.query(sql, multirows=True) - return res - else: - logger.error( - f"all_keys is only implemented 
for llm_response_cache, not for {self.namespace}" - ) - async def filter_keys(self, keys: List[str]) -> Set[str]: """Filter out duplicated content""" sql = SQL_TEMPLATES["filter_keys"].format( diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index f9283dda..025f293f 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,5 +1,5 @@ import os -from typing import Any, Union +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -20,11 +20,7 @@ class RedisKVStorage(BaseKVStorage): redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379") self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV {self.namespace}") - - async def all_keys(self) -> list[str]: - keys = await self._redis.keys(f"{self.namespace}:*") - return [key.split(":", 1)[-1] for key in keys] - + async def get_by_id(self, id): data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None @@ -57,11 +53,4 @@ class RedisKVStorage(BaseKVStorage): async def drop(self) -> None: keys = await self._redis.keys(f"{self.namespace}:*") if keys: - await self._redis.delete(*keys) - - async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: - pipe = self._redis.pipeline() - for key in await self._redis.keys(f"{self.namespace}:*"): - pipe.hgetall(key) - results = await pipe.execute() - return [data for data in results if data.get("status") == status] or None + await self._redis.delete(*keys) \ No newline at end of file diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index aaae68c9..00174fcd 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -29,6 +29,7 @@ from .base import ( BaseKVStorage, BaseVectorStorage, DocStatus, + DocStatusStorage, QueryParam, StorageNameSpace, ) @@ -319,7 +320,7 @@ class LightRAG: # Initialize document status storage self.doc_status_storage_cls = 
self._get_storage_class(self.doc_status_storage) - self.doc_status: BaseKVStorage = self.doc_status_storage_cls( + self.doc_status: DocStatusStorage = self.doc_status_storage_cls( namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), global_config=global_config, embedding_func=None, @@ -394,10 +395,8 @@ class LightRAG: split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. """ - await self.apipeline_process_documents(string_or_strings) - await self.apipeline_process_enqueue_documents( - split_by_character, split_by_character_only - ) + await self.apipeline_enqueue_documents(string_or_strings) + await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only) def insert_custom_chunks(self, full_text: str, text_chunks: list[str]): loop = always_get_an_event_loop() @@ -496,8 +495,13 @@ class LightRAG: # 3. Filter out already processed documents add_doc_keys: set[str] = set() - excluded_ids = await self.doc_status.all_keys() + # Get docs ids + in_process_keys = list(new_docs.keys()) + # Ids already tracked in doc status storage (get_by_ids yields a falsy entry for missing keys) + excluded_ids = {k for k, v in zip(in_process_keys, await self.doc_status.get_by_ids(in_process_keys)) if v} + # Exclude already in process + add_doc_keys = new_docs.keys() - excluded_ids + # Filter new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys} if not new_docs: @@ -513,12 +517,12 @@ class LightRAG: to_process_doc_keys: list[str] = [] # Fetch failed documents - failed_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED) + failed_docs = await self.doc_status.get_failed_docs() if failed_docs: to_process_doc_keys.extend([doc["id"] for doc in failed_docs]) # Fetch pending documents - pending_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING) + pending_docs = await self.doc_status.get_pending_docs() if pending_docs: to_process_doc_keys.extend([doc["id"] for doc in pending_docs])