From 2bf238396e5af40caa02cb7674b5e761e5e565f7 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 15:52:59 +0100 Subject: [PATCH] updated wrong status --- lightrag/base.py | 23 +++++-------------- lightrag/kg/json_doc_status_impl.py | 34 +++++++---------------------- lightrag/kg/mongo_impl.py | 16 -------------- lightrag/kg/postgres_impl.py | 18 +-------------- lightrag/kg/redis_impl.py | 4 ++-- 5 files changed, 17 insertions(+), 78 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index c44d6af8..98bdb606 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,9 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import StrEnum import os from dataclasses import dataclass, field -from enum import Enum from typing import ( Any, Literal, @@ -203,7 +203,7 @@ class BaseGraphStorage(StorageNameSpace, ABC): """Retrieve a subgraph of the knowledge graph starting from a given node.""" -class DocStatus(str, Enum): +class DocStatus(StrEnum): """Document processing status enum""" PENDING = "pending" @@ -245,18 +245,7 @@ class DocStatusStorage(BaseKVStorage, ABC): """Get counts of documents in each status""" @abstractmethod - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - - @abstractmethod - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - raise NotImplementedError - - @abstractmethod - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - - @abstractmethod - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 7fccf3c3..33df6d43 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -44,33 +44,15 @@ class JsonDocStatusStorage(DocStatusStorage): counts[doc["status"]] += 1 return counts - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.FAILED - } - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PENDING - } - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSED - } - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - return { - k: DocProcessingStatus(**v) - for k, v in self._data.items() - if v["status"] == DocStatus.PROCESSING - } + k: DocProcessingStatus(**v) + for k, v in self._data.items() + if v["status"] == status.value + } async def index_done_callback(self) -> None: write_json(self._data, self._file_name) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 219ec313..abc0aeb5 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -201,22 +201,6 @@ class MongoDocStatusStorage(DocStatusStorage): for doc in result } - async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self) -> None: # Implement the method here pass diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 9bd17ec5..33b4259f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -484,7 +484,7 @@ class PGDocStatusStorage(DocStatusStorage): ) -> Dict[str, DocProcessingStatus]: """all documents with a specific status""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" - params = {"workspace": self.db.workspace, "status": status} + params = {"workspace": self.db.workspace, "status": status.value} result = await self.db.query(sql, params, True) return { element["id"]: DocProcessingStatus( @@ -499,22 +499,6 @@ class PGDocStatusStorage(DocStatusStorage): for element in result } - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: - """Get all processing documents""" - return await self.get_docs_by_status(DocStatus.PROCESSING) - - async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: - """Get all procesed documents""" - return await self.get_docs_by_status(DocStatus.PROCESSED) - async def index_done_callback(self) -> None: pass diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 71e39c5c..98258741 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Any, final from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -18,7 +18,7 @@ import json config = configparser.ConfigParser() config.read("config.ini", "utf-8") - +@final @dataclass class RedisKVStorage(BaseKVStorage): def __post_init__(self):