Merge pull request #751 from ArnoChenFx/add-MongoDocStatusStorage

add MongoDocStatusStorage
This commit is contained in:
zrguo
2025-02-12 13:39:56 +08:00
committed by GitHub
6 changed files with 103 additions and 5 deletions

View File

@@ -130,7 +130,7 @@ if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database os.environ["MONGO_DATABASE"] = mongo_database
rag_storage_config.KV_STORAGE = "MongoKVStorage" rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage" rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
if mongo_graph: if mongo_graph:
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage" rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"

View File

@@ -227,6 +227,14 @@ class DocStatusStorage(BaseKVStorage):
"""Get all pending documents""" """Get all pending documents"""
raise NotImplementedError raise NotImplementedError
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
raise NotImplementedError
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None: async def update_doc_status(self, data: dict[str, Any]) -> None:
"""Updates the status of a document. By default, it calls upsert.""" """Updates the status of a document. By default, it calls upsert."""
await self.upsert(data) await self.upsert(data)

View File

@@ -16,7 +16,13 @@ from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient from pymongo import MongoClient
from ..base import BaseGraphStorage, BaseKVStorage from ..base import (
BaseGraphStorage,
BaseKVStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
@@ -39,7 +45,8 @@ class MongoKVStorage(BaseKVStorage):
async def filter_keys(self, data: set[str]) -> set[str]: async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [ existing_ids = [
str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1}) str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
] ]
return set([s for s in data if s not in existing_ids]) return set([s for s in data if s not in existing_ids])
@@ -77,6 +84,82 @@ class MongoKVStorage(BaseKVStorage):
await self._data.drop() await self._data.drop()
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
)
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as doc status {self.namespace}")
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return list(self._data.find({"_id": {"$in": ids}}))
async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
for k, v in data.items():
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
data[k]["_id"] = k
async def drop(self) -> None:
"""Drop the collection"""
await self._data.drop()
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
result = list(self._data.aggregate(pipeline))
counts = {}
for doc in result:
counts[doc["_id"]] = doc["count"]
return counts
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents by status"""
result = list(self._data.find({"status": status.value}))
return {
doc["_id"]: DocProcessingStatus(
content=doc["content"],
content_summary=doc.get("content_summary"),
content_length=doc["content_length"],
status=doc["status"],
created_at=doc.get("created_at"),
updated_at=doc.get("updated_at"),
chunks_count=doc.get("chunks_count", -1),
)
for doc in result
}
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
@dataclass @dataclass
class MongoGraphStorage(BaseGraphStorage): class MongoGraphStorage(BaseGraphStorage):
""" """

View File

@@ -495,6 +495,14 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all pending documents""" """Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING) return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self): async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!") logger.info("Doc status had been saved into postgresql db!")

View File

@@ -70,7 +70,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
return [] return []
@@ -123,5 +122,4 @@ class QdrantVectorDBStorage(BaseVectorStorage):
limit=top_k, limit=top_k,
with_payload=True, with_payload=True,
) )
logger.debug(f"query result: {results}")
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]

View File

@@ -46,6 +46,7 @@ STORAGES = {
"OracleVectorDBStorage": ".kg.oracle_impl", "OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl", "MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl", "MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl", "MongoGraphStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl", "RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl", "ChromaVectorDBStorage": ".kg.chroma_impl",