use pre-commit reformat
This commit is contained in:
@@ -114,7 +114,9 @@ class OracleDB:
|
|||||||
|
|
||||||
logger.info("Finished check all tables in Oracle database")
|
logger.info("Finished check all tables in Oracle database")
|
||||||
|
|
||||||
async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
|
async def query(
|
||||||
|
self, sql: str, params: dict = None, multirows: bool = False
|
||||||
|
) -> Union[dict, None]:
|
||||||
async with self.pool.acquire() as connection:
|
async with self.pool.acquire() as connection:
|
||||||
connection.inputtypehandler = self.input_type_handler
|
connection.inputtypehandler = self.input_type_handler
|
||||||
connection.outputtypehandler = self.output_type_handler
|
connection.outputtypehandler = self.output_type_handler
|
||||||
@@ -187,7 +189,9 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||||
"""根据 id 获取 doc_chunks 数据"""
|
"""根据 id 获取 doc_chunks 数据"""
|
||||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
|
)
|
||||||
params = {"workspace": self.db.workspace}
|
params = {"workspace": self.db.workspace}
|
||||||
# print("get_by_ids:"+SQL)
|
# print("get_by_ids:"+SQL)
|
||||||
# print(params)
|
# print(params)
|
||||||
@@ -201,8 +205,9 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||||
"""过滤掉重复内容"""
|
"""过滤掉重复内容"""
|
||||||
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||||
ids=",".join([f"'{id}'" for id in keys]))
|
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
|
||||||
|
)
|
||||||
params = {"workspace": self.db.workspace}
|
params = {"workspace": self.db.workspace}
|
||||||
try:
|
try:
|
||||||
await self.db.query(SQL, params)
|
await self.db.query(SQL, params)
|
||||||
@@ -248,14 +253,15 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
# print(list_data)
|
# print(list_data)
|
||||||
for item in list_data:
|
for item in list_data:
|
||||||
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
||||||
data = {"check_id":item["__id__"],
|
data = {
|
||||||
|
"check_id": item["__id__"],
|
||||||
"id": item["__id__"],
|
"id": item["__id__"],
|
||||||
"content": item["content"],
|
"content": item["content"],
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"tokens": item["tokens"],
|
"tokens": item["tokens"],
|
||||||
"chunk_order_index": item["chunk_order_index"],
|
"chunk_order_index": item["chunk_order_index"],
|
||||||
"full_doc_id": item["full_doc_id"],
|
"full_doc_id": item["full_doc_id"],
|
||||||
"content_vector":item["__vector__"]
|
"content_vector": item["__vector__"],
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
@@ -268,7 +274,7 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
"check_id": k,
|
"check_id": k,
|
||||||
"id": k,
|
"id": k,
|
||||||
"content": v["content"],
|
"content": v["content"],
|
||||||
"workspace":self.db.workspace
|
"workspace": self.db.workspace,
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
@@ -354,7 +360,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
"description": description,
|
"description": description,
|
||||||
"source_chunk_id": source_id,
|
"source_chunk_id": source_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
"content_vector":content_vector
|
"content_vector": content_vector,
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
@@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
keywords = edge_data["keywords"]
|
keywords = edge_data["keywords"]
|
||||||
description = edge_data["description"]
|
description = edge_data["description"]
|
||||||
source_chunk_id = edge_data["source_id"]
|
source_chunk_id = edge_data["source_id"]
|
||||||
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
|
logger.debug(
|
||||||
|
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
||||||
|
)
|
||||||
|
|
||||||
content = keywords + source_name + target_name + description
|
content = keywords + source_name + target_name + description
|
||||||
contents = [content]
|
contents = [content]
|
||||||
@@ -394,7 +402,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
"description": description,
|
"description": description,
|
||||||
"source_chunk_id": source_chunk_id,
|
"source_chunk_id": source_chunk_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
"content_vector":content_vector
|
"content_vector": content_vector,
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
@@ -428,10 +436,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""根据节点id检查节点是否存在"""
|
"""根据节点id检查节点是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_node"]
|
SQL = SQL_TEMPLATES["has_node"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"node_id":node_id
|
|
||||||
}
|
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
res = await self.db.query(SQL, params)
|
res = await self.db.query(SQL, params)
|
||||||
@@ -448,7 +453,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
params = {
|
params = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"source_node_id": source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"target_node_id":target_node_id
|
"target_node_id": target_node_id,
|
||||||
}
|
}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL, params)
|
res = await self.db.query(SQL, params)
|
||||||
@@ -462,10 +467,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""根据节点id获取节点的度"""
|
"""根据节点id获取节点的度"""
|
||||||
SQL = SQL_TEMPLATES["node_degree"]
|
SQL = SQL_TEMPLATES["node_degree"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"node_id":node_id
|
|
||||||
}
|
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL, params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
@@ -484,10 +486,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
"""根据节点id获取节点数据"""
|
"""根据节点id获取节点数据"""
|
||||||
SQL = SQL_TEMPLATES["get_node"]
|
SQL = SQL_TEMPLATES["get_node"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"node_id":node_id
|
|
||||||
}
|
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL, params)
|
res = await self.db.query(SQL, params)
|
||||||
@@ -506,7 +505,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
params = {
|
params = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"source_node_id": source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"target_node_id":target_node_id
|
"target_node_id": target_node_id,
|
||||||
}
|
}
|
||||||
res = await self.db.query(SQL, params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
@@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
"""根据节点id获取节点的所有边"""
|
"""根据节点id获取节点的所有边"""
|
||||||
if await self.has_node(source_node_id):
|
if await self.has_node(source_node_id):
|
||||||
SQL = SQL_TEMPLATES["get_node_edges"]
|
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"source_node_id":source_node_id
|
|
||||||
}
|
|
||||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||||
@@ -548,6 +544,8 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
N_T = {
|
N_T = {
|
||||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
@@ -732,5 +730,5 @@ SQL_TEMPLATES = {
|
|||||||
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
||||||
WHERE t1.workspace=:workspace
|
WHERE t1.workspace=:workspace
|
||||||
order by t1.CREATETIME DESC
|
order by t1.CREATETIME DESC
|
||||||
fetch first :limit rows only"""
|
fetch first :limit rows only""",
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user