Added @final decorators and required methods; cleaned up imports

This commit is contained in:
Yannick Stephan
2025-02-16 14:38:09 +01:00
parent 7848a38a45
commit 3fef8201c6
16 changed files with 209 additions and 316 deletions

View File

@@ -1,22 +1,11 @@
import os
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
import configparser
from tqdm.asyncio import tqdm as tqdm_async
import asyncio
if not pm.is_installed("pymongo"):
pm.install("pymongo")
if not pm.is_installed("motor"):
pm.install("motor")
from typing import Any, List, Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
from typing import Any, List, Union, final
from ..base import (
BaseGraphStorage,
@@ -30,11 +19,22 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
try:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
raise ImportError(
"motor, pymongo library is not installed. Please install it to proceed."
) from e
# Parse "config.ini" (decoded as UTF-8) into a ConfigParser. The path is
# relative, so it is resolved against the current working directory.
# NOTE(review): ConfigParser.read() skips files it cannot open instead of
# raising, so `config` may be empty if config.ini is absent — confirm callers
# provide fallbacks.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -115,6 +115,7 @@ class MongoKVStorage(BaseKVStorage):
await self._data.drop()
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
@@ -210,7 +211,15 @@ class MongoDocStatusStorage(DocStatusStorage):
"""Get all processed documents"""
return await self.get_docs_by_status(DocStatus.PROCESSED)
async def index_done_callback(self) -> None:
# Placeholder: the body is only `pass`, so awaiting this callback does
# nothing and returns None.
# NOTE(review): presumably present to satisfy a method required by the
# DocStatusStorage base class — confirm against ..base.
# TODO: Implement the method here
pass
async def update_doc_status(self, data: dict[str, Any]) -> None:
# Stub: unconditionally raises NotImplementedError; `data` is never read.
# Callers must not rely on this method for this storage class.
raise NotImplementedError
@final
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
@@ -774,11 +783,13 @@ class MongoGraphStorage(BaseGraphStorage):
return result
async def index_done_callback(self) -> None:
# No-op: the body is only `pass`, so awaiting this callback does nothing
# and returns None.
# NOTE(review): presumably present to satisfy a method required by the
# BaseGraphStorage base class — confirm against ..base.
pass
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")