pre-commit fix tidb
@@ -19,8 +19,10 @@ class TiDB(object):
         self.password = config.get("password", None)
         self.database = config.get("database", None)
         self.workspace = config.get("workspace", None)
-        connection_string = (f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
-                             f"?ssl_verify_cert=true&ssl_verify_identity=true")
+        connection_string = (
+            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
+            f"?ssl_verify_cert=true&ssl_verify_identity=true"
+        )

         try:
             self.engine = create_engine(connection_string)
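Note: the only change in the hunk above is reflowing the TiDB connection string; the string itself is unchanged. As a minimal standalone sketch (the config values and the direct usage below are illustrative, not part of the commit), the same DSN can be built and handed to SQLAlchemy like this:

```python
# Illustrative values only; in the class these come from config.get(...).
from sqlalchemy import create_engine, text

config = {"user": "root", "password": "secret", "host": "127.0.0.1",
          "port": 4000, "database": "test"}
connection_string = (
    f"mysql+pymysql://{config['user']}:{config['password']}"
    f"@{config['host']}:{config['port']}/{config['database']}"
    f"?ssl_verify_cert=true&ssl_verify_identity=true"
)
engine = create_engine(connection_string)
with engine.connect() as conn:
    print(conn.execute(text("SELECT 1")).scalar())
```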
@@ -49,7 +51,7 @@ class TiDB(object):
         self, sql: str, params: dict = None, multirows: bool = False
     ) -> Union[dict, None]:
         if params is None:
-            params = { "workspace": self.workspace }
+            params = {"workspace": self.workspace}
         else:
             params.update({"workspace": self.workspace})
         with self.engine.connect() as conn, conn.begin():
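The hunk above only tightens a dict literal, but it sits in the helper that forces a `:workspace` bind parameter into every statement. A rough standalone sketch of that pattern, assuming SQLAlchemy's `text()` and a plausible result-handling tail (the real method body is not shown in this diff):

```python
from sqlalchemy import text

def run_query(engine, workspace: str, sql: str, params: dict = None, multirows: bool = False):
    # Every statement gets the workspace, either as the only parameter
    # or merged into the caller-supplied dict (mirrors the diff).
    if params is None:
        params = {"workspace": workspace}
    else:
        params.update({"workspace": workspace})
    with engine.connect() as conn, conn.begin():
        rows = conn.execute(text(sql), params).mappings().all()
    return list(rows) if multirows else (dict(rows[0]) if rows else None)
```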
@@ -130,8 +132,8 @@ class TiDBKVStorage(BaseKVStorage):
         """Filter out duplicate content."""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=N_T[self.namespace],
-            id_field= N_ID[self.namespace],
-            ids=",".join([f"'{id}'" for id in keys])
+            id_field=N_ID[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         try:
             await self.db.query(SQL)
@@ -161,7 +163,7 @@ class TiDBKVStorage(BaseKVStorage):
         ]
         contents = [v["content"] for v in data.values()]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
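For context, the slicing touched above implements simple batching before the embeddings are gathered concurrently. A small runnable sketch of that pattern (the `embed` stub below stands in for the real `embedding_func`):

```python
import asyncio

async def embed(batch: list[str]) -> list[list[float]]:
    # Placeholder embedding: one number per string, just to make this runnable.
    return [[float(len(s))] for s in batch]

async def embed_all(contents: list[str], max_batch_size: int = 2) -> list[list[float]]:
    batches = [
        contents[i : i + max_batch_size]
        for i in range(0, len(contents), max_batch_size)
    ]
    embeddings_list = await asyncio.gather(*(embed(b) for b in batches))
    # Flatten the per-batch results back into one list.
    return [vec for batch in embeddings_list for vec in batch]

print(asyncio.run(embed_all(["a", "bb", "ccc"])))
```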
@@ -174,26 +176,30 @@ class TiDBKVStorage(BaseKVStorage):
         merge_sql = SQL_TEMPLATES["upsert_chunk"]
         data = []
         for item in list_data:
-            data.append({
-                "id": item["__id__"],
-                "content": item["content"],
-                "tokens": item["tokens"],
-                "chunk_order_index": item["chunk_order_index"],
-                "full_doc_id": item["full_doc_id"],
-                "content_vector": f"{item["__vector__"].tolist()}",
-                "workspace": self.db.workspace,
-            })
+            data.append(
+                {
+                    "id": item["__id__"],
+                    "content": item["content"],
+                    "tokens": item["tokens"],
+                    "chunk_order_index": item["chunk_order_index"],
+                    "full_doc_id": item["full_doc_id"],
+                    "content_vector": f"{item["__vector__"].tolist()}",
+                    "workspace": self.db.workspace,
+                }
+            )
         await self.db.execute(merge_sql, data)

         if self.namespace == "full_docs":
             merge_sql = SQL_TEMPLATES["upsert_doc_full"]
             data = []
             for k, v in self._data.items():
-                data.append({
-                    "id": k,
-                    "content": v["content"],
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": k,
+                        "content": v["content"],
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
         return left_data

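The `data` list assembled above is a list of dicts whose keys match the `:named` placeholders in the upsert template, so the execute call can behave like an executemany. A toy sketch of that mechanism with SQLAlchemy (SQLite stands in for TiDB, and the table and columns are invented for the example):

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # stand-in for the TiDB engine
with engine.begin() as conn:
    conn.execute(text(
        "CREATE TABLE chunks (id TEXT PRIMARY KEY, content TEXT, workspace TEXT)"
    ))
    data = [
        {"id": "chunk-1", "content": "hello", "workspace": "ws"},
        {"id": "chunk-2", "content": "world", "workspace": "ws"},
    ]
    # Passing a list of parameter dicts makes SQLAlchemy run the statement once per dict.
    conn.execute(
        text("INSERT INTO chunks (id, content, workspace) VALUES (:id, :content, :workspace)"),
        data,
    )
    print(conn.execute(text("SELECT COUNT(*) FROM chunks")).scalar())
```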
@@ -201,6 +207,7 @@ class TiDBKVStorage(BaseKVStorage):
         if self.namespace in ["full_docs", "text_chunks"]:
             logger.info("full doc and chunk data had been saved into TiDB db!")

+
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
@@ -215,7 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         )

     async def query(self, query: str, top_k: int) -> list[dict]:
-        """ search from tidb vector"""
+        """search from tidb vector"""

         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
@@ -228,8 +235,10 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             "better_than_threshold": self.cosine_better_than_threshold,
         }

-        results = await self.db.query(SQL_TEMPLATES[self.namespace], params=params, multirows=True)
-        print("vector search result:",results)
+        results = await self.db.query(
+            SQL_TEMPLATES[self.namespace], params=params, multirows=True
+        )
+        print("vector search result:", results)
         if not results:
             return []
         return results
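The `params` dict passed to the vector query above is only partially visible in this hunk; the SQL templates later in the diff expect `:embedding_string`, `:workspace`, `:top_k` and `:better_than_threshold` binds. A speculative sketch of how those parameters are presumably assembled, reusing the same `tolist()` string serialization the upserts apply to `content_vector`:

```python
import numpy as np

embedding = np.asarray([0.1, 0.2, 0.3])  # stand-in for embedding_func([query])[0]
params = {
    # Assumed serialization: the "[0.1, 0.2, 0.3]" literal mirrors how
    # content_vector is stringified elsewhere in this file.
    "embedding_string": f"{embedding.tolist()}",
    "workspace": "default",
    "top_k": 5,
    "better_than_threshold": 0.2,
}
print(params["embedding_string"])
```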
@@ -253,16 +262,16 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         ]
         contents = [v["content"] for v in data.values()]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embedding_tasks = [self.embedding_func(batch) for batch in batches]
         embeddings_list = []
         for f in tqdm(
             asyncio.as_completed(embedding_tasks),
             total=len(embedding_tasks),
             desc="Generating embeddings",
             unit="batch",
         ):
             embeddings = await f
             embeddings_list.append(embeddings)
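Unlike the `asyncio.gather` variant earlier, this path drains the embedding tasks with `asyncio.as_completed` wrapped in `tqdm`, so the progress bar advances as each batch finishes. A compact runnable sketch of the same pattern (the `embed` stub is a placeholder for `embedding_func`):

```python
import asyncio
from tqdm import tqdm

async def embed(batch: list[str]) -> list[list[float]]:
    await asyncio.sleep(0.01)  # simulate an embedding call
    return [[float(len(s))] for s in batch]

async def main():
    batches = [["a", "bb"], ["ccc", "dddd"]]
    tasks = [embed(b) for b in batches]
    embeddings_list = []
    for f in tqdm(asyncio.as_completed(tasks), total=len(tasks),
                  desc="Generating embeddings", unit="batch"):
        embeddings_list.append(await f)  # results arrive in completion order, not input order
    print(embeddings_list)

asyncio.run(main())
```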
@@ -274,27 +283,31 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             data = []
             for item in list_data:
                 merge_sql = SQL_TEMPLATES["upsert_entity"]
-                data.append({
-                    "id": item["id"],
-                    "name": item["entity_name"],
-                    "content": item["content"],
-                    "content_vector": f"{item["content_vector"].tolist()}",
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": item["id"],
+                        "name": item["entity_name"],
+                        "content": item["content"],
+                        "content_vector": f"{item["content_vector"].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)

         elif self.namespace == "relationships":
             data = []
             for item in list_data:
                 merge_sql = SQL_TEMPLATES["upsert_relationship"]
-                data.append({
-                    "id": item["id"],
-                    "source_name": item["src_id"],
-                    "target_name": item["tgt_id"],
-                    "content": item["content"],
-                    "content_vector": f"{item["content_vector"].tolist()}",
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": item["id"],
+                        "source_name": item["src_id"],
+                        "target_name": item["tgt_id"],
+                        "content": item["content"],
+                        "content_vector": f"{item["content_vector"].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)

@@ -346,8 +359,7 @@ TABLES = {
         """
     },
     "LIGHTRAG_GRAPH_NODES": {
-        "ddl":
-        """
+        "ddl": """
     CREATE TABLE LIGHTRAG_GRAPH_NODES (
         `id` BIGINT PRIMARY KEY AUTO_RANDOM,
         `entity_id` VARCHAR(256) NOT NULL,
@@ -362,8 +374,7 @@ TABLES = {
         """
     },
     "LIGHTRAG_GRAPH_EDGES": {
-        "ddl":
-        """
+        "ddl": """
     CREATE TABLE LIGHTRAG_GRAPH_EDGES (
         `id` BIGINT PRIMARY KEY AUTO_RANDOM,
         `relation_id` VARCHAR(256) NOT NULL,
@@ -400,7 +411,6 @@ SQL_TEMPLATES = {
     "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
     "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
     "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
-
     # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
     "upsert_doc_full": """
         INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
@@ -408,13 +418,12 @@ SQL_TEMPLATES = {
         ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
         """,
     "upsert_chunk": """
         INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
         VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
         ON DUPLICATE KEY UPDATE
         content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
         full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
         """,
-
     # SQL for VectorStorage
     "entities": """SELECT n.name as entity_name FROM
         (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
@@ -428,19 +437,18 @@ SQL_TEMPLATES = {
         (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
         WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
-
     "upsert_entity": """
         INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
         VALUES(:id, :name, :content, :content_vector, :workspace)
         ON DUPLICATE KEY UPDATE
         name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
         workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
         """,
     "upsert_relationship": """
         INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
         VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
         ON DUPLICATE KEY UPDATE
         source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
         content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
-        """
+        """,
 }
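These templates rely on MySQL/TiDB's `INSERT ... ON DUPLICATE KEY UPDATE` as the upsert primitive and on `VEC_COSINE_DISTANCE` for the vector search. A hedged sketch of binding one of them with SQLAlchemy (the sample values are invented, and the duplicate-key branch assumes the table carries a unique key for the entity):

```python
from sqlalchemy import text

upsert_entity = text("""
    INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
    VALUES(:id, :name, :content, :content_vector, :workspace)
    ON DUPLICATE KEY UPDATE
    name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
    workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""")
params = {
    "id": "ent-1",
    "name": "TiDB",
    "content": "TiDB is a distributed SQL database.",
    "content_vector": "[0.1, 0.2, 0.3]",  # same stringified-list form the upsert code builds
    "workspace": "default",
}
# Against the TiDB engine this inserts on first sight and updates in place
# when the unique key already exists:
# with engine.begin() as conn:
#     conn.execute(upsert_entity, params)
```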
@@ -80,6 +80,7 @@ ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBS
 TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
 TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")

+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     """
     Ensure that there is always an event loop available.
@@ -13,11 +13,11 @@ openai
 oracledb
 pymilvus
 pymongo
+pymysql
 pyvis
-tenacity
 # lmdeploy[all]
 sqlalchemy
-pymysql
+tenacity


 # LLM packages