implemented method and cleaned the mess

This commit is contained in:
Yannick Stephan
2025-02-08 23:18:12 +01:00
parent eb552afcdc
commit cff415d91f
7 changed files with 66 additions and 125 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Union
from typing import Any, TypeVar, Union
import numpy as np
import pipmaster as pm
@@ -108,7 +108,7 @@ class TiDBKVStorage(BaseKVStorage):
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict, None]:
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
"""根据 id 获取 doc_full 数据."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id}
@@ -122,16 +122,14 @@ class TiDBKVStorage(BaseKVStorage):
return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
"""根据 id 获取 doc_chunks 数据"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
# print("get_by_ids:"+SQL)
res = await self.db.query(SQL, multirows=True)
if res:
data = res # [{"data":i} for i in res]
# print(data)
return data
else:
return None
@@ -158,7 +156,7 @@ class TiDBKVStorage(BaseKVStorage):
return data
################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, Any]) -> None:
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
@@ -335,6 +333,12 @@ class TiDBVectorDBStorage(BaseVectorStorage):
merge_sql = SQL_TEMPLATES["insert_relationship"]
await self.db.execute(merge_sql, data)
async def get_by_status_and_ids(
self, status: str
) -> Union[list[dict], None]:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
@dataclass
class TiDBGraphStorage(BaseGraphStorage):