From 893b6455068a70b0716d1db2ee97aa264be2b31c Mon Sep 17 00:00:00 2001 From: ArnoChen Date: Sun, 16 Feb 2025 21:28:58 +0800 Subject: [PATCH] unify doc status retrieval with get_docs_by_status --- lightrag/base.py | 18 ++++------------ lightrag/kg/json_doc_status_impl.py | 32 +++++------------------------ lightrag/kg/mongo_impl.py | 18 +--------------- lightrag/kg/postgres_impl.py | 18 +--------------- lightrag/lightrag.py | 10 ++++----- 5 files changed, 16 insertions(+), 80 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 3d4fc022..d9a63d26 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage): """Get counts of documents in each status""" raise NotImplementedError - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - raise NotImplementedError - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - raise NotImplementedError - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - raise NotImplementedError - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" raise NotImplementedError async def update_doc_status(self, data: dict[str, Any]) -> None: diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index fad03acc..ed79a370 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage): counts[doc["status"]] += 1 return counts - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, 
DocProcessingStatus]: + """Get all documents with a specific status""" return { k: DocProcessingStatus(**v) for k, v in self._data.items() - if v["status"] == DocStatus.FAILED - } - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PENDING - } - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processed documents""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSED - } - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSING + if v["status"] == status } async def index_done_callback(self): diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index c216e7be..f6326b76 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage): async def get_docs_by_status( self, status: DocStatus ) -> dict[str, DocProcessingStatus]: - """Get all documents by status""" + """Get all documents with a specific status""" cursor = self._data.find({"status": status.value}) result = await cursor.to_list() return { @@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage): for doc in result } - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await 
self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - @dataclass class MongoGraphStorage(BaseGraphStorage): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a44aefe7..51b25385 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage): async def get_docs_by_status( self, status: DocStatus ) -> Dict[str, DocProcessingStatus]: - """Get all documents by status""" + """Get all documents with a specific status""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.db.workspace, "status": status} result = await self.db.query(sql, params, True) @@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage): for element in result } - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self): """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" logger.info("Doc status had been saved into postgresql db!") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 23c3df80..9909b4b7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -89,7 
+89,7 @@ STORAGE_IMPLEMENTATIONS = { "PGDocStatusStorage", "MongoDocStatusStorage", ], - "required_methods": ["get_pending_docs"], + "required_methods": ["get_docs_by_status"], }, } @@ -230,7 +230,7 @@ class LightRAG: """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" working_dir: str = field( - default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}' + default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) """Directory where cache and temporary files are stored.""" @@ -715,11 +715,11 @@ class LightRAG: # 1. Get all pending, failed, and abnormally terminated processing documents. to_process_docs: dict[str, DocProcessingStatus] = {} - processing_docs = await self.doc_status.get_processing_docs() + processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING) to_process_docs.update(processing_docs) - failed_docs = await self.doc_status.get_failed_docs() + failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED) to_process_docs.update(failed_docs) - pendings_docs = await self.doc_status.get_pending_docs() + pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING) to_process_docs.update(pendings_docs) if not to_process_docs: