diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8d13fab0..529b2bfb 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -130,7 +130,7 @@ if mongo_uri: os.environ["MONGO_URI"] = mongo_uri os.environ["MONGO_DATABASE"] = mongo_database rag_storage_config.KV_STORAGE = "MongoKVStorage" - rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage" + rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage" if mongo_graph: rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage" diff --git a/lightrag/base.py b/lightrag/base.py index bd79d990..3702b49e 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -227,6 +227,14 @@ class DocStatusStorage(BaseKVStorage): """Get all pending documents""" raise NotImplementedError + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + raise NotImplementedError + + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all procesed documents""" + raise NotImplementedError + async def update_doc_status(self, data: dict[str, Any]) -> None: """Updates the status of a document. By default, it calls upsert.""" await self.upsert(data) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 4f919ecd..8662d005 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -16,7 +16,13 @@ from typing import Any, List, Tuple, Union from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient -from ..base import BaseGraphStorage, BaseKVStorage +from ..base import ( + BaseGraphStorage, + BaseKVStorage, + DocProcessingStatus, + DocStatus, + DocStatusStorage, +) from ..namespace import NameSpace, is_namespace from ..utils import logger @@ -39,7 +45,8 @@ class MongoKVStorage(BaseKVStorage): async def filter_keys(self, data: set[str]) -> set[str]: existing_ids = [ - str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1}) + str(x["_id"]) + for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) ] return set([s for s in data if s not in existing_ids]) @@ -77,6 +84,82 @@ class MongoKVStorage(BaseKVStorage): await self._data.drop() +@dataclass +class MongoDocStatusStorage(DocStatusStorage): + def __post_init__(self): + client = MongoClient( + os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + ) + database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) + self._data = database.get_collection(self.namespace) + logger.info(f"Use MongoDB as doc status {self.namespace}") + + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + return self._data.find_one({"_id": id}) + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + return list(self._data.find({"_id": {"$in": ids}})) + + async def filter_keys(self, data: set[str]) -> set[str]: + existing_ids = [ + str(x["_id"]) + for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + ] + return set([s for s in data if s not in existing_ids]) + + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + for k, v in data.items(): + self._data.update_one({"_id": k}, {"$set": v}, upsert=True) + data[k]["_id"] = k + + async def drop(self) -> None: + """Drop the collection""" + await self._data.drop() + + async def get_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status""" + pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}] + result = list(self._data.aggregate(pipeline)) + counts = {} + for doc in result: + counts[doc["_id"]] = doc["count"] + return counts + + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents by status""" + result = list(self._data.find({"status": status.value})) + return { + doc["_id"]: DocProcessingStatus( + content=doc["content"], + content_summary=doc.get("content_summary"), + content_length=doc["content_length"], + status=doc["status"], + created_at=doc.get("created_at"), + updated_at=doc.get("updated_at"), + chunks_count=doc.get("chunks_count", -1), + ) + for doc in result + } + + async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all failed documents""" + return await self.get_docs_by_status(DocStatus.FAILED) + + async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: + """Get all pending documents""" + return await self.get_docs_by_status(DocStatus.PENDING) + + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + return await self.get_docs_by_status(DocStatus.PROCESSING) + + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all procesed documents""" + return await self.get_docs_by_status(DocStatus.PROCESSED) + + @dataclass class MongoGraphStorage(BaseGraphStorage): """ diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 79496abd..00ab9f0b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -495,6 +495,14 @@ class PGDocStatusStorage(DocStatusStorage): """Get all pending documents""" return await self.get_docs_by_status(DocStatus.PENDING) + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all processing documents""" + return await self.get_docs_by_status(DocStatus.PROCESSING) + + async def get_processed_docs(self) -> dict[str, DocProcessingStatus]: + """Get all procesed documents""" + return await self.get_docs_by_status(DocStatus.PROCESSED) + async def index_done_callback(self): """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" logger.info("Doc status had been saved into postgresql db!") diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index e2f8d3a2..8c7a7029 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -70,7 +70,6 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) async def upsert(self, data: dict[str, dict]): - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") if not len(data): logger.warning("You insert an empty data to vector DB") return [] @@ -123,5 +122,4 @@ class QdrantVectorDBStorage(BaseVectorStorage): limit=top_k, with_payload=True, ) - logger.debug(f"query result: {results}") return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 070ac11d..e25f4879 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -46,6 +46,7 @@ STORAGES = { "OracleVectorDBStorage": ".kg.oracle_impl", "MilvusVectorDBStorge": ".kg.milvus_impl", "MongoKVStorage": ".kg.mongo_impl", + "MongoDocStatusStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl", "RedisKVStorage": ".kg.redis_impl", "ChromaVectorDBStorage": ".kg.chroma_impl",