pre-commit fix tidb

This commit is contained in:
Weaxs
2024-12-12 10:21:51 +08:00
parent 8ef5a6b8cd
commit 288985eab4
3 changed files with 73 additions and 64 deletions

View File

@@ -19,8 +19,10 @@ class TiDB(object):
self.password = config.get("password", None) self.password = config.get("password", None)
self.database = config.get("database", None) self.database = config.get("database", None)
self.workspace = config.get("workspace", None) self.workspace = config.get("workspace", None)
connection_string = (f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" connection_string = (
f"?ssl_verify_cert=true&ssl_verify_identity=true") f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
f"?ssl_verify_cert=true&ssl_verify_identity=true"
)
try: try:
self.engine = create_engine(connection_string) self.engine = create_engine(connection_string)
@@ -49,7 +51,7 @@ class TiDB(object):
self, sql: str, params: dict = None, multirows: bool = False self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]: ) -> Union[dict, None]:
if params is None: if params is None:
params = { "workspace": self.workspace } params = {"workspace": self.workspace}
else: else:
params.update({"workspace": self.workspace}) params.update({"workspace": self.workspace})
with self.engine.connect() as conn, conn.begin(): with self.engine.connect() as conn, conn.begin():
@@ -130,8 +132,8 @@ class TiDBKVStorage(BaseKVStorage):
"""过滤掉重复内容""" """过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], table_name=N_T[self.namespace],
id_field= N_ID[self.namespace], id_field=N_ID[self.namespace],
ids=",".join([f"'{id}'" for id in keys]) ids=",".join([f"'{id}'" for id in keys]),
) )
try: try:
await self.db.query(SQL) await self.db.query(SQL)
@@ -161,7 +163,7 @@ class TiDBKVStorage(BaseKVStorage):
] ]
contents = [v["content"] for v in data.values()] contents = [v["content"] for v in data.values()]
batches = [ batches = [
contents[i: i + self._max_batch_size] contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size) for i in range(0, len(contents), self._max_batch_size)
] ]
embeddings_list = await asyncio.gather( embeddings_list = await asyncio.gather(
@@ -174,7 +176,8 @@ class TiDBKVStorage(BaseKVStorage):
merge_sql = SQL_TEMPLATES["upsert_chunk"] merge_sql = SQL_TEMPLATES["upsert_chunk"]
data = [] data = []
for item in list_data: for item in list_data:
data.append({ data.append(
{
"id": item["__id__"], "id": item["__id__"],
"content": item["content"], "content": item["content"],
"tokens": item["tokens"], "tokens": item["tokens"],
@@ -182,18 +185,21 @@ class TiDBKVStorage(BaseKVStorage):
"full_doc_id": item["full_doc_id"], "full_doc_id": item["full_doc_id"],
"content_vector": f"{item["__vector__"].tolist()}", "content_vector": f"{item["__vector__"].tolist()}",
"workspace": self.db.workspace, "workspace": self.db.workspace,
}) }
)
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
if self.namespace == "full_docs": if self.namespace == "full_docs":
merge_sql = SQL_TEMPLATES["upsert_doc_full"] merge_sql = SQL_TEMPLATES["upsert_doc_full"]
data = [] data = []
for k, v in self._data.items(): for k, v in self._data.items():
data.append({ data.append(
{
"id": k, "id": k,
"content": v["content"], "content": v["content"],
"workspace": self.db.workspace, "workspace": self.db.workspace,
}) }
)
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
return left_data return left_data
@@ -201,6 +207,7 @@ class TiDBKVStorage(BaseKVStorage):
if self.namespace in ["full_docs", "text_chunks"]: if self.namespace in ["full_docs", "text_chunks"]:
logger.info("full doc and chunk data had been saved into TiDB db!") logger.info("full doc and chunk data had been saved into TiDB db!")
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2 cosine_better_than_threshold: float = 0.2
@@ -215,7 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
) )
async def query(self, query: str, top_k: int) -> list[dict]: async def query(self, query: str, top_k: int) -> list[dict]:
""" search from tidb vector""" """search from tidb vector"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
@@ -228,8 +235,10 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"better_than_threshold": self.cosine_better_than_threshold, "better_than_threshold": self.cosine_better_than_threshold,
} }
results = await self.db.query(SQL_TEMPLATES[self.namespace], params=params, multirows=True) results = await self.db.query(
print("vector search result:",results) SQL_TEMPLATES[self.namespace], params=params, multirows=True
)
print("vector search result:", results)
if not results: if not results:
return [] return []
return results return results
@@ -253,7 +262,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
] ]
contents = [v["content"] for v in data.values()] contents = [v["content"] for v in data.values()]
batches = [ batches = [
contents[i: i + self._max_batch_size] contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size) for i in range(0, len(contents), self._max_batch_size)
] ]
embedding_tasks = [self.embedding_func(batch) for batch in batches] embedding_tasks = [self.embedding_func(batch) for batch in batches]
@@ -274,27 +283,31 @@ class TiDBVectorDBStorage(BaseVectorStorage):
data = [] data = []
for item in list_data: for item in list_data:
merge_sql = SQL_TEMPLATES["upsert_entity"] merge_sql = SQL_TEMPLATES["upsert_entity"]
data.append({ data.append(
{
"id": item["id"], "id": item["id"],
"name": item["entity_name"], "name": item["entity_name"],
"content": item["content"], "content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}", "content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace, "workspace": self.db.workspace,
}) }
)
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
elif self.namespace == "relationships": elif self.namespace == "relationships":
data = [] data = []
for item in list_data: for item in list_data:
merge_sql = SQL_TEMPLATES["upsert_relationship"] merge_sql = SQL_TEMPLATES["upsert_relationship"]
data.append({ data.append(
{
"id": item["id"], "id": item["id"],
"source_name": item["src_id"], "source_name": item["src_id"],
"target_name": item["tgt_id"], "target_name": item["tgt_id"],
"content": item["content"], "content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}", "content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace, "workspace": self.db.workspace,
}) }
)
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
@@ -346,8 +359,7 @@ TABLES = {
""" """
}, },
"LIGHTRAG_GRAPH_NODES": { "LIGHTRAG_GRAPH_NODES": {
"ddl": "ddl": """
"""
CREATE TABLE LIGHTRAG_GRAPH_NODES ( CREATE TABLE LIGHTRAG_GRAPH_NODES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM, `id` BIGINT PRIMARY KEY AUTO_RANDOM,
`entity_id` VARCHAR(256) NOT NULL, `entity_id` VARCHAR(256) NOT NULL,
@@ -362,8 +374,7 @@ TABLES = {
""" """
}, },
"LIGHTRAG_GRAPH_EDGES": { "LIGHTRAG_GRAPH_EDGES": {
"ddl": "ddl": """
"""
CREATE TABLE LIGHTRAG_GRAPH_EDGES ( CREATE TABLE LIGHTRAG_GRAPH_EDGES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM, `id` BIGINT PRIMARY KEY AUTO_RANDOM,
`relation_id` VARCHAR(256) NOT NULL, `relation_id` VARCHAR(256) NOT NULL,
@@ -400,7 +411,6 @@ SQL_TEMPLATES = {
"get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace", "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
"get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace", "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
"filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace", "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
# SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE) # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
"upsert_doc_full": """ "upsert_doc_full": """
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace) INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
@@ -414,7 +424,6 @@ SQL_TEMPLATES = {
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""", """,
# SQL for VectorStorage # SQL for VectorStorage
"entities": """SELECT n.name as entity_name FROM "entities": """SELECT n.name as entity_name FROM
(SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
@@ -428,7 +437,6 @@ SQL_TEMPLATES = {
(SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""", WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
"upsert_entity": """ "upsert_entity": """
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace) INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
VALUES(:id, :name, :content, :content_vector, :workspace) VALUES(:id, :name, :content, :content_vector, :workspace)
@@ -442,5 +450,5 @@ SQL_TEMPLATES = {
ON DUPLICATE KEY UPDATE ON DUPLICATE KEY UPDATE
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""" """,
} }

View File

@@ -80,6 +80,7 @@ ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBS
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
def always_get_an_event_loop() -> asyncio.AbstractEventLoop: def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
""" """
Ensure that there is always an event loop available. Ensure that there is always an event loop available.

View File

@@ -13,11 +13,11 @@ openai
oracledb oracledb
pymilvus pymilvus
pymongo pymongo
pymysql
pyvis pyvis
tenacity
# lmdeploy[all] # lmdeploy[all]
sqlalchemy sqlalchemy
pymysql tenacity
# LLM packages # LLM packages