added final, required methods and cleaned import

This commit is contained in:
Yannick Stephan
2025-02-16 14:38:09 +01:00
parent 7848a38a45
commit 3fef8201c6
16 changed files with 209 additions and 316 deletions

View File

@@ -1,24 +1,26 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any, Union, final
import numpy as np
import pipmaster as pm
if not pm.is_installed("pymysql"):
pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from lightrag.types import KnowledgeGraph
from sqlalchemy import create_engine, text
from tqdm import tqdm
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger
try:
from sqlalchemy import create_engine, text
except ImportError as e:
raise ImportError(
"pymysql, sqlalchemy library is not installed. Please install it to proceed."
) from e
class TiDB:
def __init__(self, config, **kwargs):
@@ -100,6 +102,7 @@ class TiDB:
raise
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use
@@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage):
await self.db.execute(merge_sql, data)
return left_data
async def index_done_callback(self):
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
):
logger.info("full doc and chunk data had been saved into TiDB db!")
async def index_done_callback(self) -> None:
pass
async def drop(self) -> None:
raise NotImplementedError
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use
# db: TiDB
cosine_better_than_threshold: float = None
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
async def index_done_callback(self) -> None:
raise NotImplementedError
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
@@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage):
else:
return []
async def index_done_callback(self) -> None:
pass
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError