unify doc status retrieval with get_docs_by_status

This commit is contained in:
ArnoChen
2025-02-16 21:28:58 +08:00
parent b580e47324
commit 893b645506
5 changed files with 16 additions and 80 deletions

View File

@@ -249,20 +249,10 @@ class DocStatusStorage(BaseKVStorage):
"""Get counts of documents in each status""" """Get counts of documents in each status"""
raise NotImplementedError raise NotImplementedError
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: async def get_docs_by_status(
"""Get all failed documents""" self, status: DocStatus
raise NotImplementedError ) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
raise NotImplementedError
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
raise NotImplementedError
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
raise NotImplementedError raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None: async def update_doc_status(self, data: dict[str, Any]) -> None:

View File

@@ -93,36 +93,14 @@ class JsonDocStatusStorage(DocStatusStorage):
counts[doc["status"]] += 1 counts[doc["status"]] += 1
return counts return counts
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: async def get_docs_by_status(
"""Get all failed documents""" self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
return { return {
k: DocProcessingStatus(**v) k: DocProcessingStatus(**v)
for k, v in self._data.items() for k, v in self._data.items()
if v["status"] == DocStatus.FAILED if v["status"] == status
}
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PENDING
}
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSED
}
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return {
k: DocProcessingStatus(**v)
for k, v in self._data.items()
if v["status"] == DocStatus.PROCESSING
} }
async def index_done_callback(self): async def index_done_callback(self):

View File

@@ -175,7 +175,7 @@ class MongoDocStatusStorage(DocStatusStorage):
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus
) -> dict[str, DocProcessingStatus]: ) -> dict[str, DocProcessingStatus]:
"""Get all documents by status""" """Get all documents with a specific status"""
cursor = self._data.find({"status": status.value}) cursor = self._data.find({"status": status.value})
result = await cursor.to_list() result = await cursor.to_list()
return { return {
@@ -191,22 +191,6 @@ class MongoDocStatusStorage(DocStatusStorage):
for doc in result for doc in result
} }
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
@dataclass @dataclass
class MongoGraphStorage(BaseGraphStorage): class MongoGraphStorage(BaseGraphStorage):

View File

@@ -468,7 +468,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_docs_by_status( async def get_docs_by_status(
self, status: DocStatus self, status: DocStatus
) -> Dict[str, DocProcessingStatus]: ) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status""" """Get all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
result = await self.db.query(sql, params, True) result = await self.db.query(sql, params, True)
@@ -485,22 +485,6 @@ class PGDocStatusStorage(DocStatusStorage):
for element in result for element in result
} }
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self): async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!") logger.info("Doc status had been saved into postgresql db!")

View File

@@ -89,7 +89,7 @@ STORAGE_IMPLEMENTATIONS = {
"PGDocStatusStorage", "PGDocStatusStorage",
"MongoDocStatusStorage", "MongoDocStatusStorage",
], ],
"required_methods": ["get_pending_docs"], "required_methods": ["get_docs_by_status"],
}, },
} }
@@ -230,7 +230,7 @@ class LightRAG:
"""LightRAG: Simple and Fast Retrieval-Augmented Generation.""" """LightRAG: Simple and Fast Retrieval-Augmented Generation."""
working_dir: str = field( working_dir: str = field(
default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}' default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
) )
"""Directory where cache and temporary files are stored.""" """Directory where cache and temporary files are stored."""
@@ -715,11 +715,11 @@ class LightRAG:
# 1. Get all pending, failed, and abnormally terminated processing documents. # 1. Get all pending, failed, and abnormally terminated processing documents.
to_process_docs: dict[str, DocProcessingStatus] = {} to_process_docs: dict[str, DocProcessingStatus] = {}
processing_docs = await self.doc_status.get_processing_docs() processing_docs = await self.doc_status.get_docs_by_status(DocStatus.PROCESSING)
to_process_docs.update(processing_docs) to_process_docs.update(processing_docs)
failed_docs = await self.doc_status.get_failed_docs() failed_docs = await self.doc_status.get_docs_by_status(DocStatus.FAILED)
to_process_docs.update(failed_docs) to_process_docs.update(failed_docs)
pendings_docs = await self.doc_status.get_pending_docs() pendings_docs = await self.doc_status.get_docs_by_status(DocStatus.PENDING)
to_process_docs.update(pendings_docs) to_process_docs.update(pendings_docs)
if not to_process_docs: if not to_process_docs: