updated clean of what implemented on DocStatusStorage
This commit is contained in:
@@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"):
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from sqlalchemy import create_engine, text
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
#################### upsert method ################
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
@@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
) -> None:
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
@@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
}
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
@@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||
return degree
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||
sql = SQL_TEMPLATES["get_node"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
) -> dict[str, str] | None:
|
||||
sql = SQL_TEMPLATES["get_edge"]
|
||||
param = {
|
||||
"source_name": source_node_id,
|
||||
@@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> Union[list[tuple[str, str]], None]:
|
||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||
sql = SQL_TEMPLATES["get_node_edges"]
|
||||
param = {"source_name": source_node_id, "workspace": self.db.workspace}
|
||||
res = await self.db.query(sql, param, multirows=True)
|
||||
@@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage):
|
||||
else:
|
||||
return []
|
||||
|
||||
async def delete_node(self, node_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
|
||||
raise NotImplementedError
|
||||
|
||||
N_T = {
|
||||
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
|
||||
|
Reference in New Issue
Block a user