Merge branch 'main' into select-datastore-in-api-server

This commit is contained in:
yangdx
2025-02-13 11:25:52 +08:00
9 changed files with 187 additions and 9 deletions

View File

@@ -228,7 +228,11 @@ class DocStatusStorage(BaseKVStorage):
         raise NotImplementedError
 
     async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
-        """Get all documents that are currently being processed"""
+        """Get all processing documents"""
         raise NotImplementedError
 
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        raise NotImplementedError
+
     async def update_doc_status(self, data: dict[str, Any]) -> None:
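For orientation, a minimal usage sketch (not part of this commit): polling the new accessor alongside the existing one on any concrete DocStatusStorage. The helper name is hypothetical.

from lightrag.base import DocStatusStorage

# Hypothetical helper, illustration only: summarize pipeline progress
# via the accessors declared on DocStatusStorage above.
async def report_progress(storage: DocStatusStorage) -> None:
    processing = await storage.get_processing_docs()
    processed = await storage.get_processed_docs()
    print(f"{len(processing)} doc(s) in flight, {len(processed)} done")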

View File

@@ -14,7 +14,14 @@ if not pm.is_installed("motor"):
 from typing import Any, List, Tuple, Union
 
 from motor.motor_asyncio import AsyncIOMotorClient
+from pymongo import MongoClient
 
-from ..base import BaseGraphStorage, BaseKVStorage
+from ..base import (
+    BaseGraphStorage,
+    BaseKVStorage,
+    DocProcessingStatus,
+    DocStatus,
+    DocStatusStorage,
+)
 from ..namespace import NameSpace, is_namespace
 from ..utils import logger
@@ -51,7 +58,8 @@ class MongoKVStorage(BaseKVStorage):
 
     async def filter_keys(self, data: set[str]) -> set[str]:
         existing_ids = [
-            str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
+            str(x["_id"])
+            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
         ]
         return set([s for s in data if s not in existing_ids])
@@ -89,6 +97,82 @@ class MongoKVStorage(BaseKVStorage):
         await self._data.drop()
 
 
+@dataclass
+class MongoDocStatusStorage(DocStatusStorage):
+    def __post_init__(self):
+        client = MongoClient(
+            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
+        )
+        database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
+        self._data = database.get_collection(self.namespace)
+        logger.info(f"Use MongoDB as doc status storage {self.namespace}")
+
+    async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
+        return self._data.find_one({"_id": id})
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        return list(self._data.find({"_id": {"$in": ids}}))
+
+    async def filter_keys(self, data: set[str]) -> set[str]:
+        existing_ids = [
+            str(x["_id"])
+            for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
+        ]
+        return set([s for s in data if s not in existing_ids])
+
+    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
+        for k, v in data.items():
+            self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
+            data[k]["_id"] = k
+
+    async def drop(self) -> None:
+        """Drop the collection"""
+        self._data.drop()  # pymongo collections are synchronous; no await here
+
+    async def get_status_counts(self) -> dict[str, int]:
+        """Get counts of documents in each status"""
+        pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
+        result = list(self._data.aggregate(pipeline))
+        counts = {}
+        for doc in result:
+            counts[doc["_id"]] = doc["count"]
+        return counts
+
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> dict[str, DocProcessingStatus]:
+        """Get all documents by status"""
+        result = list(self._data.find({"status": status.value}))
+        return {
+            doc["_id"]: DocProcessingStatus(
+                content=doc["content"],
+                content_summary=doc.get("content_summary"),
+                content_length=doc["content_length"],
+                status=doc["status"],
+                created_at=doc.get("created_at"),
+                updated_at=doc.get("updated_at"),
+                chunks_count=doc.get("chunks_count", -1),
+            )
+            for doc in result
+        }
+
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        return await self.get_docs_by_status(DocStatus.FAILED)
+
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        return await self.get_docs_by_status(DocStatus.PENDING)
+
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
+        return await self.get_docs_by_status(DocStatus.PROCESSING)
+
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return await self.get_docs_by_status(DocStatus.PROCESSED)
+
+
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """

View File

@@ -493,10 +493,14 @@ class PGDocStatusStorage(DocStatusStorage):
         """Get all pending documents"""
         return await self.get_docs_by_status(DocStatus.PENDING)
 
-    async def get_processing_docs(self) -> Dict[str, DocProcessingStatus]:
-        """Get all documents that are currently being processed"""
+    async def get_processing_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processing documents"""
         return await self.get_docs_by_status(DocStatus.PROCESSING)
 
+    async def get_processed_docs(self) -> dict[str, DocProcessingStatus]:
+        """Get all processed documents"""
+        return await self.get_docs_by_status(DocStatus.PROCESSED)
+
     async def index_done_callback(self):
         """Save data after indexing; PostgreSQL already persisted everything during upsert, so there is nothing to do here"""
         logger.info("Doc status has been saved into the PostgreSQL database!")

View File

@@ -87,7 +87,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
         )
 
     async def upsert(self, data: dict[str, dict]):
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
         if not len(data):
             logger.warning("You inserted empty data into the vector DB")
             return []
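With the info log removed, an empty payload now produces only the warning; a hedged illustration, assuming storage is a wired-up QdrantVectorDBStorage:

async def demo(storage) -> None:
    # An empty dict short-circuits before any Qdrant call:
    result = await storage.upsert({})  # warning logged, nothing sent
    assert result == []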

View File

@@ -137,6 +137,7 @@ STORAGE_ENV_REQUIREMENTS = {
     # Document Status Storage Implementations
     "JsonDocStatusStorage": [],
     "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "MongoDocStatusStorage": [],
 }
 
 # Storage implementation module mapping
@@ -151,6 +152,7 @@ STORAGES = {
     "OracleVectorDBStorage": ".kg.oracle_impl",
     "MilvusVectorDBStorge": ".kg.milvus_impl",
     "MongoKVStorage": ".kg.mongo_impl",
+    "MongoDocStatusStorage": ".kg.mongo_impl",
     "MongoGraphStorage": ".kg.mongo_impl",
     "RedisKVStorage": ".kg.redis_impl",
     "ChromaVectorDBStorage": ".kg.chroma_impl",