cleaned code

This commit is contained in:
Yannick Stephan
2025-02-09 19:51:05 +01:00
parent 55cfb4dab1
commit 6480ddee5d
8 changed files with 77 additions and 69 deletions

View File

@@ -14,12 +14,12 @@ if not pm.is_installed("sqlalchemy"):
from sqlalchemy import create_engine, text
from tqdm import tqdm
from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from ..utils import logger
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger
class TiDB(object):
class TiDB:
def __init__(self, config, **kwargs):
self.host = config.get("host", None)
self.port = config.get("port", None)
@@ -108,12 +108,12 @@ class TiDBKVStorage(BaseKVStorage):
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any]:
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
"""Fetch doc_full data by id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id}
# print("get_by_id:"+SQL)
return await self.db.query(SQL, params)
response = await self.db.query(SQL, params)
return response if response else None
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
@@ -178,7 +178,7 @@ class TiDBKVStorage(BaseKVStorage):
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": f"{item['__vector__'].tolist()}",
"content_vector": f'{item["__vector__"].tolist()}',
"workspace": self.db.workspace,
}
)
@@ -222,8 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
)
async def query(self, query: str, top_k: int) -> list[dict]:
"""search from tidb vector"""
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
@@ -286,7 +285,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f"{item['content_vector'].tolist()}",
"content_vector": f'{item["content_vector"].tolist()}',
"workspace": self.db.workspace,
}
# update entity_id if node inserted by graph_storage_instance before
@@ -308,7 +307,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f"{item['content_vector'].tolist()}",
"content_vector": f'{item["content_vector"].tolist()}',
"workspace": self.db.workspace,
}
# update relation_id if node inserted by graph_storage_instance before