added final, required methods and cleaned import
This commit is contained in:
@@ -1,24 +1,26 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, final
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("pymysql"):
|
||||
pm.install("pymysql")
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"pymysql, sqlalchemy library is not installed. Please install it to proceed."
|
||||
) from e
|
||||
|
||||
|
||||
class TiDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
@@ -100,6 +102,7 @@ class TiDB:
|
||||
raise
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -200,23 +203,16 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
await self.db.execute(merge_sql, data)
|
||||
return left_data
|
||||
|
||||
async def index_done_callback(self):
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
):
|
||||
logger.info("full doc and chunk data had been saved into TiDB db!")
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
# db instance must be injected before use
|
||||
# db: TiDB
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
@@ -343,7 +339,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
@@ -481,6 +481,9 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
else:
|
||||
return []
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
Reference in New Issue
Block a user