Merge branch 'main' into graph-viewer-webui

This commit is contained in:
ArnoChen
2025-02-13 04:42:57 +08:00
9 changed files with 191 additions and 6 deletions

View File

@@ -130,7 +130,7 @@ if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoDocStatusStorage"
if mongo_graph:
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"

View File

@@ -227,6 +227,14 @@ class DocStatusStorage(BaseKVStorage):
"""Get all pending documents"""
raise NotImplementedError
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
raise NotImplementedError
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
raise NotImplementedError
async def update_doc_status(self, data: dict[str, Any]) -> None:
"""Updates the status of a document. By default, it calls upsert."""
await self.upsert(data)

View File

@@ -16,7 +16,13 @@ from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from ..base import BaseGraphStorage, BaseKVStorage
from ..base import (
BaseGraphStorage,
BaseKVStorage,
DocProcessingStatus,
DocStatus,
DocStatusStorage,
)
from ..namespace import NameSpace, is_namespace
from ..utils import logger
@@ -39,7 +45,8 @@ class MongoKVStorage(BaseKVStorage):
async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
@@ -77,6 +84,82 @@ class MongoKVStorage(BaseKVStorage):
await self._data.drop()
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
)
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as doc status {self.namespace}")
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return self._data.find_one({"_id": id})
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
return list(self._data.find({"_id": {"$in": ids}}))
async def filter_keys(self, data: set[str]) -> set[str]:
existing_ids = [
str(x["_id"])
for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
for k, v in data.items():
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
data[k]["_id"] = k
async def drop(self) -> None:
"""Drop the collection"""
await self._data.drop()
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
result = list(self._data.aggregate(pipeline))
counts = {}
for doc in result:
counts[doc["_id"]] = doc["count"]
return counts
async def get_docs_by_status(
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents by status"""
result = list(self._data.find({"status": status.value}))
return {
doc["_id"]: DocProcessingStatus(
content=doc["content"],
content_summary=doc.get("content_summary"),
content_length=doc["content_length"],
status=doc["status"],
created_at=doc.get("created_at"),
updated_at=doc.get("updated_at"),
chunks_count=doc.get("chunks_count", -1),
)
for doc in result
}
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""

View File

@@ -495,6 +495,14 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all processing documents"""
return await self.get_docs_by_status(DocStatus.PROCESSING)
async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all procesed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self):
"""Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into postgresql db!")

View File

@@ -46,6 +46,7 @@ STORAGES = {
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",