cleaned code
This commit is contained in:
@@ -14,12 +14,12 @@ if not pm.is_installed("sqlalchemy"):
|
||||
from sqlalchemy import create_engine, text
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
|
||||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
|
||||
from ..namespace import NameSpace, is_namespace
|
||||
from ..utils import logger
|
||||
|
||||
|
||||
class TiDB(object):
|
||||
class TiDB:
|
||||
def __init__(self, config, **kwargs):
|
||||
self.host = config.get("host", None)
|
||||
self.port = config.get("port", None)
|
||||
@@ -108,12 +108,12 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
"""Fetch doc_full data by id."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"id": id}
|
||||
# print("get_by_id:"+SQL)
|
||||
return await self.db.query(SQL, params)
|
||||
response = await self.db.query(SQL, params)
|
||||
return response if response else None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
|
||||
@@ -178,7 +178,7 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
"tokens": item["tokens"],
|
||||
"chunk_order_index": item["chunk_order_index"],
|
||||
"full_doc_id": item["full_doc_id"],
|
||||
"content_vector": f"{item['__vector__'].tolist()}",
|
||||
"content_vector": f'{item["__vector__"].tolist()}',
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
)
|
||||
@@ -222,8 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||
"""search from tidb vector"""
|
||||
|
||||
"""Search from tidb vector"""
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
|
||||
@@ -286,7 +285,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
"id": item["id"],
|
||||
"name": item["entity_name"],
|
||||
"content": item["content"],
|
||||
"content_vector": f"{item['content_vector'].tolist()}",
|
||||
"content_vector": f'{item["content_vector"].tolist()}',
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# update entity_id if node inserted by graph_storage_instance before
|
||||
@@ -308,7 +307,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
"source_name": item["src_id"],
|
||||
"target_name": item["tgt_id"],
|
||||
"content": item["content"],
|
||||
"content_vector": f"{item['content_vector'].tolist()}",
|
||||
"content_vector": f'{item["content_vector"].tolist()}',
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# update relation_id if node inserted by graph_storage_instance before
|
||||
|
Reference in New Issue
Block a user