Updated: clean-up of what is implemented on DocStatusStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:53:59 +01:00
parent 71a18d1de9
commit 882190a515
9 changed files with 164 additions and 168 deletions

View File

@@ -11,6 +11,7 @@ if not pm.is_installed("pymysql"):
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from lightrag.types import KnowledgeGraph
from sqlalchemy import create_engine, text
from tqdm import tqdm
@@ -352,7 +353,7 @@ class TiDBGraphStorage(BaseGraphStorage):
self._max_batch_size = self.global_config["embedding_batch_num"]
#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
@@ -383,7 +384,7 @@ class TiDBGraphStorage(BaseGraphStorage):
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
) -> None:
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
@@ -419,7 +420,7 @@ class TiDBGraphStorage(BaseGraphStorage):
}
await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
@@ -452,14 +453,14 @@ class TiDBGraphStorage(BaseGraphStorage):
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree
async def get_node(self, node_id: str) -> Union[dict, None]:
async def get_node(self, node_id: str) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_edge"]
param = {
"source_name": source_node_id,
@@ -468,9 +469,7 @@ class TiDBGraphStorage(BaseGraphStorage):
}
return await self.db.query(sql, param)
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True)
@@ -480,6 +479,14 @@ class TiDBGraphStorage(BaseGraphStorage):
else:
return []
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",