use pre-commit reformat
This commit is contained in:
@@ -114,7 +114,9 @@ class OracleDB:
|
|||||||
|
|
||||||
logger.info("Finished check all tables in Oracle database")
|
logger.info("Finished check all tables in Oracle database")
|
||||||
|
|
||||||
async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
|
async def query(
|
||||||
|
self, sql: str, params: dict = None, multirows: bool = False
|
||||||
|
) -> Union[dict, None]:
|
||||||
async with self.pool.acquire() as connection:
|
async with self.pool.acquire() as connection:
|
||||||
connection.inputtypehandler = self.input_type_handler
|
connection.inputtypehandler = self.input_type_handler
|
||||||
connection.outputtypehandler = self.output_type_handler
|
connection.outputtypehandler = self.output_type_handler
|
||||||
@@ -174,9 +176,9 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||||
"""根据 id 获取 doc_full 数据."""
|
"""根据 id 获取 doc_full 数据."""
|
||||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||||
params = {"workspace":self.db.workspace, "id":id}
|
params = {"workspace": self.db.workspace, "id": id}
|
||||||
# print("get_by_id:"+SQL)
|
# print("get_by_id:"+SQL)
|
||||||
res = await self.db.query(SQL,params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
data = res # {"data":res}
|
data = res # {"data":res}
|
||||||
# print (data)
|
# print (data)
|
||||||
@@ -187,11 +189,13 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
# Query by id
|
# Query by id
|
||||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||||
"""根据 id 获取 doc_chunks 数据"""
|
"""根据 id 获取 doc_chunks 数据"""
|
||||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
|
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||||
params = {"workspace":self.db.workspace}
|
ids=",".join([f"'{id}'" for id in ids])
|
||||||
#print("get_by_ids:"+SQL)
|
)
|
||||||
#print(params)
|
params = {"workspace": self.db.workspace}
|
||||||
res = await self.db.query(SQL,params, multirows=True)
|
# print("get_by_ids:"+SQL)
|
||||||
|
# print(params)
|
||||||
|
res = await self.db.query(SQL, params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = res # [{"data":i} for i in res]
|
data = res # [{"data":i} for i in res]
|
||||||
# print(data)
|
# print(data)
|
||||||
@@ -201,16 +205,17 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||||
"""过滤掉重复内容"""
|
"""过滤掉重复内容"""
|
||||||
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||||
ids=",".join([f"'{id}'" for id in keys]))
|
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
|
||||||
params = {"workspace":self.db.workspace}
|
)
|
||||||
|
params = {"workspace": self.db.workspace}
|
||||||
try:
|
try:
|
||||||
await self.db.query(SQL, params)
|
await self.db.query(SQL, params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Oracle database error: {e}")
|
logger.error(f"Oracle database error: {e}")
|
||||||
print(SQL)
|
print(SQL)
|
||||||
print(params)
|
print(params)
|
||||||
res = await self.db.query(SQL, params,multirows=True)
|
res = await self.db.query(SQL, params, multirows=True)
|
||||||
data = None
|
data = None
|
||||||
if res:
|
if res:
|
||||||
exist_keys = [key["id"] for key in res]
|
exist_keys = [key["id"] for key in res]
|
||||||
@@ -248,15 +253,16 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
# print(list_data)
|
# print(list_data)
|
||||||
for item in list_data:
|
for item in list_data:
|
||||||
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
||||||
data = {"check_id":item["__id__"],
|
data = {
|
||||||
"id":item["__id__"],
|
"check_id": item["__id__"],
|
||||||
"content":item["content"],
|
"id": item["__id__"],
|
||||||
"workspace":self.db.workspace,
|
"content": item["content"],
|
||||||
"tokens":item["tokens"],
|
"workspace": self.db.workspace,
|
||||||
"chunk_order_index":item["chunk_order_index"],
|
"tokens": item["tokens"],
|
||||||
"full_doc_id":item["full_doc_id"],
|
"chunk_order_index": item["chunk_order_index"],
|
||||||
"content_vector":item["__vector__"]
|
"full_doc_id": item["full_doc_id"],
|
||||||
}
|
"content_vector": item["__vector__"],
|
||||||
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
|
|
||||||
@@ -265,11 +271,11 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
# values.clear()
|
# values.clear()
|
||||||
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
||||||
data = {
|
data = {
|
||||||
"check_id":k,
|
"check_id": k,
|
||||||
"id":k,
|
"id": k,
|
||||||
"content":v["content"],
|
"content": v["content"],
|
||||||
"workspace":self.db.workspace
|
"workspace": self.db.workspace,
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql, data)
|
await self.db.execute(merge_sql, data)
|
||||||
return left_data
|
return left_data
|
||||||
@@ -301,7 +307,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
# 转换精度
|
# 转换精度
|
||||||
dtype = str(embedding.dtype).upper()
|
dtype = str(embedding.dtype).upper()
|
||||||
dimension = embedding.shape[0]
|
dimension = embedding.shape[0]
|
||||||
embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
|
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
||||||
|
|
||||||
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
||||||
params = {
|
params = {
|
||||||
@@ -309,9 +315,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
"better_than_threshold": self.cosine_better_than_threshold,
|
"better_than_threshold": self.cosine_better_than_threshold,
|
||||||
}
|
}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
results = await self.db.query(SQL,params=params, multirows=True)
|
results = await self.db.query(SQL, params=params, multirows=True)
|
||||||
# print("vector search result:",results)
|
# print("vector search result:",results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -348,16 +354,16 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
content_vector = embeddings[0]
|
content_vector = embeddings[0]
|
||||||
merge_sql = SQL_TEMPLATES["merge_node"]
|
merge_sql = SQL_TEMPLATES["merge_node"]
|
||||||
data = {
|
data = {
|
||||||
"workspace":self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"name":entity_name,
|
"name": entity_name,
|
||||||
"entity_type":entity_type,
|
"entity_type": entity_type,
|
||||||
"description":description,
|
"description": description,
|
||||||
"source_chunk_id":source_id,
|
"source_chunk_id": source_id,
|
||||||
"content":content,
|
"content": content,
|
||||||
"content_vector":content_vector
|
"content_vector": content_vector,
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql,data)
|
await self.db.execute(merge_sql, data)
|
||||||
# self._graph.add_node(node_id, **node_data)
|
# self._graph.add_node(node_id, **node_data)
|
||||||
|
|
||||||
async def upsert_edge(
|
async def upsert_edge(
|
||||||
@@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
keywords = edge_data["keywords"]
|
keywords = edge_data["keywords"]
|
||||||
description = edge_data["description"]
|
description = edge_data["description"]
|
||||||
source_chunk_id = edge_data["source_id"]
|
source_chunk_id = edge_data["source_id"]
|
||||||
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
|
logger.debug(
|
||||||
|
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
||||||
|
)
|
||||||
|
|
||||||
content = keywords + source_name + target_name + description
|
content = keywords + source_name + target_name + description
|
||||||
contents = [content]
|
contents = [content]
|
||||||
@@ -386,18 +394,18 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
content_vector = embeddings[0]
|
content_vector = embeddings[0]
|
||||||
merge_sql = SQL_TEMPLATES["merge_edge"]
|
merge_sql = SQL_TEMPLATES["merge_edge"]
|
||||||
data = {
|
data = {
|
||||||
"workspace":self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"source_name":source_name,
|
"source_name": source_name,
|
||||||
"target_name":target_name,
|
"target_name": target_name,
|
||||||
"weight":weight,
|
"weight": weight,
|
||||||
"keywords":keywords,
|
"keywords": keywords,
|
||||||
"description":description,
|
"description": description,
|
||||||
"source_chunk_id":source_chunk_id,
|
"source_chunk_id": source_chunk_id,
|
||||||
"content":content,
|
"content": content,
|
||||||
"content_vector":content_vector
|
"content_vector": content_vector,
|
||||||
}
|
}
|
||||||
# print(merge_sql)
|
# print(merge_sql)
|
||||||
await self.db.execute(merge_sql,data)
|
await self.db.execute(merge_sql, data)
|
||||||
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||||
|
|
||||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||||
@@ -428,13 +436,10 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""根据节点id检查节点是否存在"""
|
"""根据节点id检查节点是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_node"]
|
SQL = SQL_TEMPLATES["has_node"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"node_id":node_id
|
|
||||||
}
|
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
res = await self.db.query(SQL,params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
# print("Node exist!",res)
|
# print("Node exist!",res)
|
||||||
return True
|
return True
|
||||||
@@ -446,12 +451,12 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
"""根据源和目标节点id检查边是否存在"""
|
"""根据源和目标节点id检查边是否存在"""
|
||||||
SQL = SQL_TEMPLATES["has_edge"]
|
SQL = SQL_TEMPLATES["has_edge"]
|
||||||
params = {
|
params = {
|
||||||
"workspace":self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"source_node_id":source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"target_node_id":target_node_id
|
"target_node_id": target_node_id,
|
||||||
}
|
}
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL,params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
# print("Edge exist!",res)
|
# print("Edge exist!",res)
|
||||||
return True
|
return True
|
||||||
@@ -462,12 +467,9 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""根据节点id获取节点的度"""
|
"""根据节点id获取节点的度"""
|
||||||
SQL = SQL_TEMPLATES["node_degree"]
|
SQL = SQL_TEMPLATES["node_degree"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"node_id":node_id
|
|
||||||
}
|
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL,params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
# print("Node degree",res["degree"])
|
# print("Node degree",res["degree"])
|
||||||
return res["degree"]
|
return res["degree"]
|
||||||
@@ -484,13 +486,10 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
"""根据节点id获取节点数据"""
|
"""根据节点id获取节点数据"""
|
||||||
SQL = SQL_TEMPLATES["get_node"]
|
SQL = SQL_TEMPLATES["get_node"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"node_id":node_id
|
|
||||||
}
|
|
||||||
# print(self.db.workspace, node_id)
|
# print(self.db.workspace, node_id)
|
||||||
# print(SQL)
|
# print(SQL)
|
||||||
res = await self.db.query(SQL,params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
# print("Get node!",self.db.workspace, node_id,res)
|
# print("Get node!",self.db.workspace, node_id,res)
|
||||||
return res
|
return res
|
||||||
@@ -504,11 +503,11 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
"""根据源和目标节点id获取边"""
|
"""根据源和目标节点id获取边"""
|
||||||
SQL = SQL_TEMPLATES["get_edge"]
|
SQL = SQL_TEMPLATES["get_edge"]
|
||||||
params = {
|
params = {
|
||||||
"workspace":self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"source_node_id":source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"target_node_id":target_node_id
|
"target_node_id": target_node_id,
|
||||||
}
|
}
|
||||||
res = await self.db.query(SQL,params)
|
res = await self.db.query(SQL, params)
|
||||||
if res:
|
if res:
|
||||||
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
||||||
return res
|
return res
|
||||||
@@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
"""根据节点id获取节点的所有边"""
|
"""根据节点id获取节点的所有边"""
|
||||||
if await self.has_node(source_node_id):
|
if await self.has_node(source_node_id):
|
||||||
SQL = SQL_TEMPLATES["get_node_edges"]
|
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||||
params = {
|
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
||||||
"workspace":self.db.workspace,
|
|
||||||
"source_node_id":source_node_id
|
|
||||||
}
|
|
||||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||||
@@ -536,18 +532,20 @@ class OracleGraphStorage(BaseGraphStorage):
|
|||||||
async def get_all_nodes(self, limit: int):
|
async def get_all_nodes(self, limit: int):
|
||||||
"""查询所有节点"""
|
"""查询所有节点"""
|
||||||
SQL = SQL_TEMPLATES["get_all_nodes"]
|
SQL = SQL_TEMPLATES["get_all_nodes"]
|
||||||
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
||||||
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_all_edges(self, limit: int):
|
async def get_all_edges(self, limit: int):
|
||||||
"""查询所有边"""
|
"""查询所有边"""
|
||||||
SQL = SQL_TEMPLATES["get_all_edges"]
|
SQL = SQL_TEMPLATES["get_all_edges"]
|
||||||
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
||||||
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||||
if res:
|
if res:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
N_T = {
|
N_T = {
|
||||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||||
@@ -719,18 +717,18 @@ SQL_TEMPLATES = {
|
|||||||
WHEN NOT MATCHED THEN
|
WHEN NOT MATCHED THEN
|
||||||
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
||||||
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
|
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
|
||||||
"get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
|
"get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
|
||||||
FROM LIGHTRAG_GRAPH_NODES t1
|
FROM LIGHTRAG_GRAPH_NODES t1
|
||||||
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
||||||
WHERE t1.workspace=:workspace
|
WHERE t1.workspace=:workspace
|
||||||
order by t1.CREATETIME DESC
|
order by t1.CREATETIME DESC
|
||||||
fetch first :limit rows only
|
fetch first :limit rows only
|
||||||
""",
|
""",
|
||||||
"get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
"get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
||||||
t1.weight,t1.DESCRIPTION,t2.content
|
t1.weight,t1.DESCRIPTION,t2.content
|
||||||
FROM LIGHTRAG_GRAPH_EDGES t1
|
FROM LIGHTRAG_GRAPH_EDGES t1
|
||||||
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
||||||
WHERE t1.workspace=:workspace
|
WHERE t1.workspace=:workspace
|
||||||
order by t1.CREATETIME DESC
|
order by t1.CREATETIME DESC
|
||||||
fetch first :limit rows only"""
|
fetch first :limit rows only""",
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user