Update oracle_impl.py
This commit is contained in:
@@ -114,7 +114,9 @@ class OracleDB:
|
||||
|
||||
logger.info("Finished check all tables in Oracle database")
|
||||
|
||||
async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
|
||||
async def query(
|
||||
self, sql: str, params: dict = None, multirows: bool = False
|
||||
) -> Union[dict, None]:
|
||||
async with self.pool.acquire() as connection:
|
||||
connection.inputtypehandler = self.input_type_handler
|
||||
connection.outputtypehandler = self.output_type_handler
|
||||
@@ -187,7 +189,9 @@ class OracleKVStorage(BaseKVStorage):
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||
"""根据 id 获取 doc_chunks 数据"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
params = {"workspace": self.db.workspace}
|
||||
# print("get_by_ids:"+SQL)
|
||||
# print(params)
|
||||
@@ -201,8 +205,9 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""过滤掉重复内容"""
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
||||
ids=",".join([f"'{id}'" for id in keys]))
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
|
||||
)
|
||||
params = {"workspace": self.db.workspace}
|
||||
try:
|
||||
await self.db.query(SQL, params)
|
||||
@@ -248,14 +253,15 @@ class OracleKVStorage(BaseKVStorage):
|
||||
# print(list_data)
|
||||
for item in list_data:
|
||||
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
||||
data = {"check_id":item["__id__"],
|
||||
data = {
|
||||
"check_id": item["__id__"],
|
||||
"id": item["__id__"],
|
||||
"content": item["content"],
|
||||
"workspace": self.db.workspace,
|
||||
"tokens": item["tokens"],
|
||||
"chunk_order_index": item["chunk_order_index"],
|
||||
"full_doc_id": item["full_doc_id"],
|
||||
"content_vector":item["__vector__"]
|
||||
"content_vector": item["__vector__"],
|
||||
}
|
||||
# print(merge_sql)
|
||||
await self.db.execute(merge_sql, data)
|
||||
@@ -268,7 +274,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
"check_id": k,
|
||||
"id": k,
|
||||
"content": v["content"],
|
||||
"workspace":self.db.workspace
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# print(merge_sql)
|
||||
await self.db.execute(merge_sql, data)
|
||||
@@ -354,7 +360,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
"description": description,
|
||||
"source_chunk_id": source_id,
|
||||
"content": content,
|
||||
"content_vector":content_vector
|
||||
"content_vector": content_vector,
|
||||
}
|
||||
# print(merge_sql)
|
||||
await self.db.execute(merge_sql, data)
|
||||
@@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
|
||||
logger.debug(
|
||||
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
||||
)
|
||||
|
||||
content = keywords + source_name + target_name + description
|
||||
contents = [content]
|
||||
@@ -394,7 +402,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
"description": description,
|
||||
"source_chunk_id": source_chunk_id,
|
||||
"content": content,
|
||||
"content_vector":content_vector
|
||||
"content_vector": content_vector,
|
||||
}
|
||||
# print(merge_sql)
|
||||
await self.db.execute(merge_sql, data)
|
||||
@@ -428,10 +436,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""根据节点id检查节点是否存在"""
|
||||
SQL = SQL_TEMPLATES["has_node"]
|
||||
params = {
|
||||
"workspace":self.db.workspace,
|
||||
"node_id":node_id
|
||||
}
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(SQL)
|
||||
# print(self.db.workspace, node_id)
|
||||
res = await self.db.query(SQL, params)
|
||||
@@ -448,7 +453,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_node_id": source_node_id,
|
||||
"target_node_id":target_node_id
|
||||
"target_node_id": target_node_id,
|
||||
}
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
@@ -462,10 +467,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""根据节点id获取节点的度"""
|
||||
SQL = SQL_TEMPLATES["node_degree"]
|
||||
params = {
|
||||
"workspace":self.db.workspace,
|
||||
"node_id":node_id
|
||||
}
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
@@ -484,10 +486,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""根据节点id获取节点数据"""
|
||||
SQL = SQL_TEMPLATES["get_node"]
|
||||
params = {
|
||||
"workspace":self.db.workspace,
|
||||
"node_id":node_id
|
||||
}
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(self.db.workspace, node_id)
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
@@ -506,7 +505,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_node_id": source_node_id,
|
||||
"target_node_id":target_node_id
|
||||
"target_node_id": target_node_id,
|
||||
}
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
@@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
"""根据节点id获取节点的所有边"""
|
||||
if await self.has_node(source_node_id):
|
||||
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||
params = {
|
||||
"workspace":self.db.workspace,
|
||||
"source_node_id":source_node_id
|
||||
}
|
||||
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||
if res:
|
||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||
@@ -556,6 +552,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
if res:
|
||||
return res
|
||||
|
||||
|
||||
N_T = {
|
||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
|
Reference in New Issue
Block a user