Update oracle_impl.py

This commit is contained in:
jin
2024-11-25 14:15:10 +08:00
parent 26ae240c65
commit 776ba2f2ce

View File

@@ -114,9 +114,7 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database") logger.info("Finished check all tables in Oracle database")
async def query( async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler connection.outputtypehandler = self.output_type_handler
@@ -175,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
async def get_by_id(self, id: str) -> Union[dict, None]: async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据.""" """根据 id 获取 doc_full 数据."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
workspace=self.db.workspace, id=id params = {"workspace":self.db.workspace, "id":id}
)
# print("get_by_id:"+SQL) # print("get_by_id:"+SQL)
res = await self.db.query(SQL) res = await self.db.query(SQL,params)
if res: if res:
data = res # {"data":res} data = res # {"data":res}
# print (data) # print (data)
@@ -190,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
# Query by id # Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据""" """根据 id 获取 doc_chunks 数据"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids]) params = {"workspace":self.db.workspace}
)
#print("get_by_ids:"+SQL) #print("get_by_ids:"+SQL)
res = await self.db.query(SQL, multirows=True) #print(params)
res = await self.db.query(SQL,params, multirows=True)
if res: if res:
data = res # [{"data":i} for i in res] data = res # [{"data":i} for i in res]
# print(data) # print(data)
@@ -204,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容""" """过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]))
workspace=self.db.workspace, params = {"workspace":self.db.workspace}
ids=",".join([f"'{k}'" for k in keys]), try:
) await self.db.query(SQL, params)
res = await self.db.query(SQL, multirows=True) except Exception as e:
logger.error(f"Oracle database error: {e}")
print(SQL)
print(params)
res = await self.db.query(SQL, params,multirows=True)
data = None data = None
if res: if res:
exist_keys = [key["id"] for key in res] exist_keys = [key["id"] for key in res]
@@ -246,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
d["__vector__"] = embeddings[i] d["__vector__"] = embeddings[i]
# print(list_data) # print(list_data)
for item in list_data: for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) merge_sql = SQL_TEMPLATES["merge_chunk"]
data = {"check_id":item["__id__"],
values = [ "id":item["__id__"],
item["__id__"], "content":item["content"],
item["content"], "workspace":self.db.workspace,
self.db.workspace, "tokens":item["tokens"],
item["tokens"], "chunk_order_index":item["chunk_order_index"],
item["chunk_order_index"], "full_doc_id":item["full_doc_id"],
item["full_doc_id"], "content_vector":item["__vector__"]
item["__vector__"], }
]
# print(merge_sql) # print(merge_sql)
await self.db.execute(merge_sql, values) await self.db.execute(merge_sql, data)
if self.namespace == "full_docs": if self.namespace == "full_docs":
for k, v in self._data.items(): for k, v in self._data.items():
# values.clear() # values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"].format( merge_sql = SQL_TEMPLATES["merge_doc_full"]
check_id=k, data = {
) "check_id":k,
values = [k, self._data[k]["content"], self.db.workspace] "id":k,
"content":v["content"],
"workspace":self.db.workspace
}
# print(merge_sql) # print(merge_sql)
await self.db.execute(merge_sql, values) await self.db.execute(merge_sql, data)
return left_data return left_data
async def index_done_callback(self): async def index_done_callback(self):
@@ -298,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
# 转换精度 # 转换精度
dtype = str(embedding.dtype).upper() dtype = str(embedding.dtype).upper()
dimension = embedding.shape[0] dimension = embedding.shape[0]
embedding_string = ", ".join(map(str, embedding.tolist())) embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
SQL = SQL_TEMPLATES[self.namespace].format( SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
embedding_string=embedding_string, params = {
dimension=dimension, "embedding_string": embedding_string,
dtype=dtype, "workspace": self.db.workspace,
workspace=self.db.workspace, "top_k": top_k,
top_k=top_k, "better_than_threshold": self.cosine_better_than_threshold,
better_than_threshold=self.cosine_better_than_threshold, }
)
# print(SQL) # print(SQL)
results = await self.db.query(SQL, multirows=True) results = await self.db.query(SQL,params=params, multirows=True)
# print("vector search result:",results) # print("vector search result:",results)
return results return results
@@ -344,22 +346,18 @@ class OracleGraphStorage(BaseGraphStorage):
) )
embeddings = np.concatenate(embeddings_list) embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0] content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["merge_node"].format( merge_sql = SQL_TEMPLATES["merge_node"]
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id data = {
) "workspace":self.db.workspace,
"name":entity_name,
"entity_type":entity_type,
"description":description,
"source_chunk_id":source_id,
"content":content,
"content_vector":content_vector
}
# print(merge_sql) # print(merge_sql)
await self.db.execute( await self.db.execute(merge_sql,data)
merge_sql,
[
self.db.workspace,
entity_name,
entity_type,
description,
source_id,
content,
content_vector,
],
)
# self._graph.add_node(node_id, **node_data) # self._graph.add_node(node_id, **node_data)
async def upsert_edge( async def upsert_edge(
@@ -373,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"] keywords = edge_data["keywords"]
description = edge_data["description"] description = edge_data["description"]
source_chunk_id = edge_data["source_id"] source_chunk_id = edge_data["source_id"]
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
content = keywords + source_name + target_name + description content = keywords + source_name + target_name + description
contents = [content] contents = [content]
batches = [ batches = [
@@ -384,27 +384,20 @@ class OracleGraphStorage(BaseGraphStorage):
) )
embeddings = np.concatenate(embeddings_list) embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0] content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["merge_edge"].format( merge_sql = SQL_TEMPLATES["merge_edge"]
workspace=self.db.workspace, data = {
source_name=source_name, "workspace":self.db.workspace,
target_name=target_name, "source_name":source_name,
source_chunk_id=source_chunk_id, "target_name":target_name,
) "weight":weight,
"keywords":keywords,
"description":description,
"source_chunk_id":source_chunk_id,
"content":content,
"content_vector":content_vector
}
# print(merge_sql) # print(merge_sql)
await self.db.execute( await self.db.execute(merge_sql,data)
merge_sql,
[
self.db.workspace,
source_name,
target_name,
weight,
keywords,
description,
source_chunk_id,
content,
content_vector,
],
)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -434,12 +427,14 @@ class OracleGraphStorage(BaseGraphStorage):
#################### query method ################# #################### query method #################
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在""" """根据节点id检查节点是否存在"""
SQL = SQL_TEMPLATES["has_node"].format( SQL = SQL_TEMPLATES["has_node"]
workspace=self.db.workspace, node_id=node_id params = {
) "workspace":self.db.workspace,
"node_id":node_id
}
# print(SQL) # print(SQL)
# print(self.db.workspace, node_id) # print(self.db.workspace, node_id)
res = await self.db.query(SQL) res = await self.db.query(SQL,params)
if res: if res:
# print("Node exist!",res) # print("Node exist!",res)
return True return True
@@ -449,13 +444,14 @@ class OracleGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在""" """根据源和目标节点id检查边是否存在"""
SQL = SQL_TEMPLATES["has_edge"].format( SQL = SQL_TEMPLATES["has_edge"]
workspace=self.db.workspace, params = {
source_node_id=source_node_id, "workspace":self.db.workspace,
target_node_id=target_node_id, "source_node_id":source_node_id,
) "target_node_id":target_node_id
}
# print(SQL) # print(SQL)
res = await self.db.query(SQL) res = await self.db.query(SQL,params)
if res: if res:
# print("Edge exist!",res) # print("Edge exist!",res)
return True return True
@@ -465,11 +461,13 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度""" """根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"].format( SQL = SQL_TEMPLATES["node_degree"]
workspace=self.db.workspace, node_id=node_id params = {
) "workspace":self.db.workspace,
"node_id":node_id
}
# print(SQL) # print(SQL)
res = await self.db.query(SQL) res = await self.db.query(SQL,params)
if res: if res:
# print("Node degree",res["degree"]) # print("Node degree",res["degree"])
return res["degree"] return res["degree"]
@@ -485,12 +483,14 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
"""根据节点id获取节点数据""" """根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"].format( SQL = SQL_TEMPLATES["get_node"]
workspace=self.db.workspace, node_id=node_id params = {
) "workspace":self.db.workspace,
"node_id":node_id
}
# print(self.db.workspace, node_id) # print(self.db.workspace, node_id)
# print(SQL) # print(SQL)
res = await self.db.query(SQL) res = await self.db.query(SQL,params)
if res: if res:
# print("Get node!",self.db.workspace, node_id,res) # print("Get node!",self.db.workspace, node_id,res)
return res return res
@@ -502,12 +502,13 @@ class OracleGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> Union[dict, None]:
"""根据源和目标节点id获取边""" """根据源和目标节点id获取边"""
SQL = SQL_TEMPLATES["get_edge"].format( SQL = SQL_TEMPLATES["get_edge"]
workspace=self.db.workspace, params = {
source_node_id=source_node_id, "workspace":self.db.workspace,
target_node_id=target_node_id, "source_node_id":source_node_id,
) "target_node_id":target_node_id
res = await self.db.query(SQL) }
res = await self.db.query(SQL,params)
if res: if res:
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
return res return res
@@ -518,10 +519,12 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node_edges(self, source_node_id: str): async def get_node_edges(self, source_node_id: str):
"""根据节点id获取节点的所有边""" """根据节点id获取节点的所有边"""
if await self.has_node(source_node_id): if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"].format( SQL = SQL_TEMPLATES["get_node_edges"]
workspace=self.db.workspace, source_node_id=source_node_id params = {
) "workspace":self.db.workspace,
res = await self.db.query(sql=SQL, multirows=True) "source_node_id":source_node_id
}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res: if res:
data = [(i["source_name"], i["target_name"]) for i in res] data = [(i["source_name"], i["target_name"]) for i in res]
# print("Get node edge!",self.db.workspace, source_node_id,data) # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -530,6 +533,28 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Node Edge not exist!",self.db.workspace, source_node_id) # print("Node Edge not exist!",self.db.workspace, source_node_id)
return [] return []
async def get_all_nodes(self, limit: int):
"""查询所有节点"""
SQL = SQL_TEMPLATES["get_all_nodes"]
params = {"workspace":self.db.workspace, "limit":str(limit)}
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
async def get_all_edges(self, limit: int):
"""查询所有边"""
SQL = SQL_TEMPLATES["get_all_edges"]
params = {"workspace":self.db.workspace, "limit":str(limit)}
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
async def get_statistics(self):
SQL = SQL_TEMPLATES["get_statistics"]
params = {"workspace":self.db.workspace}
res = await self.db.query(sql=SQL,params=params, multirows=True)
if res:
return res
N_T = { N_T = {
"full_docs": "LIGHTRAG_DOC_FULL", "full_docs": "LIGHTRAG_DOC_FULL",
@@ -701,5 +726,37 @@ SQL_TEMPLATES = {
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
WHEN NOT MATCHED THEN WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
"get_all_nodes":"""WITH t0 AS (
SELECT name AS id, entity_type AS label, entity_type, description,
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
FROM lightrag_graph_nodes
WHERE workspace = :workspace
ORDER BY createtime DESC fetch first :limit rows only
), t1 AS (
SELECT t0.id, source_chunk_id
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
), t2 AS (
SELECT t1.id, LISTAGG(t2.content, '\n') content
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
GROUP BY t1.id
)
SELECT t0.id, label, entity_type, description, t2.content
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
"get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
t1.weight,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_EDGES t1
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
fetch first :limit rows only""",
"get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
FROM (
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
UNION
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
)""",
} }