use pre-commit reformat

This commit is contained in:
tmuife
2024-11-18 13:52:49 +08:00
parent 1123ccfbd0
commit e5f2aa3a30

View File

@@ -114,7 +114,9 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
async def query(
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
@@ -187,7 +189,9 @@ class OracleKVStorage(BaseKVStorage):
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
# print("get_by_ids:"+SQL)
# print(params)
@@ -201,8 +205,9 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
ids=",".join([f"'{id}'" for id in keys]))
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
)
params = {"workspace": self.db.workspace}
try:
await self.db.query(SQL, params)
@@ -248,14 +253,15 @@ class OracleKVStorage(BaseKVStorage):
# print(list_data)
for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"]
data = {"check_id":item["__id__"],
data = {
"check_id": item["__id__"],
"id": item["__id__"],
"content": item["content"],
"workspace": self.db.workspace,
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector":item["__vector__"]
"content_vector": item["__vector__"],
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
@@ -268,7 +274,7 @@ class OracleKVStorage(BaseKVStorage):
"check_id": k,
"id": k,
"content": v["content"],
"workspace":self.db.workspace
"workspace": self.db.workspace,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
@@ -354,7 +360,7 @@ class OracleGraphStorage(BaseGraphStorage):
"description": description,
"source_chunk_id": source_id,
"content": content,
"content_vector":content_vector
"content_vector": content_vector,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
@@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
logger.debug(
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
)
content = keywords + source_name + target_name + description
contents = [content]
@@ -394,7 +402,7 @@ class OracleGraphStorage(BaseGraphStorage):
"description": description,
"source_chunk_id": source_chunk_id,
"content": content,
"content_vector":content_vector
"content_vector": content_vector,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
@@ -428,10 +436,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在"""
SQL = SQL_TEMPLATES["has_node"]
params = {
"workspace":self.db.workspace,
"node_id":node_id
}
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
# print(self.db.workspace, node_id)
res = await self.db.query(SQL, params)
@@ -448,7 +453,7 @@ class OracleGraphStorage(BaseGraphStorage):
params = {
"workspace": self.db.workspace,
"source_node_id": source_node_id,
"target_node_id":target_node_id
"target_node_id": target_node_id,
}
# print(SQL)
res = await self.db.query(SQL, params)
@@ -462,10 +467,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"]
params = {
"workspace":self.db.workspace,
"node_id":node_id
}
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
res = await self.db.query(SQL, params)
if res:
@@ -484,10 +486,7 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> Union[dict, None]:
"""根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"]
params = {
"workspace":self.db.workspace,
"node_id":node_id
}
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(self.db.workspace, node_id)
# print(SQL)
res = await self.db.query(SQL, params)
@@ -506,7 +505,7 @@ class OracleGraphStorage(BaseGraphStorage):
params = {
"workspace": self.db.workspace,
"source_node_id": source_node_id,
"target_node_id":target_node_id
"target_node_id": target_node_id,
}
res = await self.db.query(SQL, params)
if res:
@@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage):
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"]
params = {
"workspace":self.db.workspace,
"source_node_id":source_node_id
}
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
@@ -548,6 +544,8 @@ class OracleGraphStorage(BaseGraphStorage):
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
return res
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -732,5 +730,5 @@ SQL_TEMPLATES = {
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
fetch first :limit rows only"""
fetch first :limit rows only""",
}