From 8e16f0815ce75a81797fb45d2f755204af9cd638 Mon Sep 17 00:00:00 2001 From: tmuife <43266626@qq.com> Date: Mon, 18 Nov 2024 10:00:06 +0800 Subject: [PATCH 1/2] change the type of binding parameters in Oracle23AI --- lightrag/kg/oracle_impl.py | 320 +++++++++++++++++++++---------------- 1 file changed, 178 insertions(+), 142 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 96a9e795..fd8bf536 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -114,16 +114,17 @@ class OracleDB: logger.info("Finished check all tables in Oracle database") - async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]: + async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]: async with self.pool.acquire() as connection: connection.inputtypehandler = self.input_type_handler connection.outputtypehandler = self.output_type_handler with connection.cursor() as cursor: try: - await cursor.execute(sql) + await cursor.execute(sql, params) except Exception as e: logger.error(f"Oracle database error: {e}") print(sql) + print(params) raise columns = [column[0].lower() for column in cursor.description] if multirows: @@ -140,7 +141,7 @@ class OracleDB: data = None return data - async def execute(self, sql: str, data: list = None): + async def execute(self, sql: str, data: list | dict = None): # logger.info("go into OracleDB execute method") try: async with self.pool.acquire() as connection: @@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> Union[dict, None]: """根据 id 获取 doc_full 数据.""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( - workspace=self.db.workspace, id=id - ) + SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] + params = {"workspace":self.db.workspace, "id":id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: data = res # {"data":res} # print (data) @@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """根据 id 获取 doc_chunks 数据""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids]) - ) - # print("get_by_ids:"+SQL) - res = await self.db.query(SQL, multirows=True) + SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids])) + params = {"workspace":self.db.workspace} + #print("get_by_ids:"+SQL) + #print(params) + res = await self.db.query(SQL,params, multirows=True) if res: data = res # [{"data":i} for i in res] # print(data) @@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" - SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=N_T[self.namespace], - workspace=self.db.workspace, - ids=",".join([f"'{k}'" for k in keys]), - ) - res = await self.db.query(SQL, multirows=True) + SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], + ids=",".join([f"'{id}'" for id in keys])) + params = {"workspace":self.db.workspace} + try: + await self.db.query(SQL, params) + except Exception as e: + logger.error(f"Oracle database error: {e}") + print(SQL) + print(params) + res = await self.db.query(SQL, params,multirows=True) data = None if res: exist_keys = [key["id"] for key in res] @@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage): d["__vector__"] = embeddings[i] # print(list_data) for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) - - values = [ - item["__id__"], - item["content"], - self.db.workspace, - item["tokens"], - item["chunk_order_index"], - item["full_doc_id"], - item["__vector__"], - ] + merge_sql = SQL_TEMPLATES["merge_chunk"] + data = {"check_id":item["__id__"], + "id":item["__id__"], + "content":item["content"], + "workspace":self.db.workspace, + "tokens":item["tokens"], + "chunk_order_index":item["chunk_order_index"], + "full_doc_id":item["full_doc_id"], + "content_vector":item["__vector__"] + } # print(merge_sql) - await self.db.execute(merge_sql, values) + await self.db.execute(merge_sql, data) if self.namespace == "full_docs": for k, v in self._data.items(): # values.clear() - merge_sql = SQL_TEMPLATES["merge_doc_full"].format( - check_id=k, - ) - values = [k, self._data[k]["content"], self.db.workspace] + merge_sql = SQL_TEMPLATES["merge_doc_full"] + data = { + "check_id":k, + "id":k, + "content":v["content"], + "workspace":self.db.workspace + } # print(merge_sql) - await self.db.execute(merge_sql, values) + await self.db.execute(merge_sql, data) return left_data async def index_done_callback(self): @@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage): # 转换精度 dtype = str(embedding.dtype).upper() dimension = embedding.shape[0] - embedding_string = ", ".join(map(str, embedding.tolist())) + embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]" - SQL = SQL_TEMPLATES[self.namespace].format( - embedding_string=embedding_string, - dimension=dimension, - dtype=dtype, - workspace=self.db.workspace, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) + SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) + params = { + "embedding_string": embedding_string, + "workspace": self.db.workspace, + "top_k": top_k, + "better_than_threshold": self.cosine_better_than_threshold, + } # print(SQL) - results = await self.db.query(SQL, multirows=True) + results = await self.db.query(SQL,params=params, multirows=True) # print("vector search result:",results) return results @@ -328,6 +333,8 @@ class OracleGraphStorage(BaseGraphStorage): entity_type = node_data["entity_type"] description = node_data["description"] source_id = node_data["source_id"] + logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") + content = entity_name + description contents = [content] batches = [ @@ -339,22 +346,18 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"].format( - workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id - ) + merge_sql = SQL_TEMPLATES["merge_node"] + data = { + "workspace":self.db.workspace, + "name":entity_name, + "entity_type":entity_type, + "description":description, + "source_chunk_id":source_id, + "content":content, + "content_vector":content_vector + } # print(merge_sql) - await self.db.execute( - merge_sql, - [ - self.db.workspace, - entity_name, - entity_type, - description, - source_id, - content, - content_vector, - ], - ) + await self.db.execute(merge_sql,data) # self._graph.add_node(node_id, **node_data) async def upsert_edge( @@ -368,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage): keywords = edge_data["keywords"] description = edge_data["description"] source_chunk_id = edge_data["source_id"] + logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}") + content = keywords + source_name + target_name + description contents = [content] batches = [ @@ -379,27 +384,20 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"].format( - workspace=self.db.workspace, - source_name=source_name, - target_name=target_name, - source_chunk_id=source_chunk_id, - ) + merge_sql = SQL_TEMPLATES["merge_edge"] + data = { + "workspace":self.db.workspace, + "source_name":source_name, + "target_name":target_name, + "weight":weight, + "keywords":keywords, + "description":description, + "source_chunk_id":source_chunk_id, + "content":content, + "content_vector":content_vector + } # print(merge_sql) - await self.db.execute( - merge_sql, - [ - self.db.workspace, - source_name, - target_name, - weight, - keywords, - description, - source_chunk_id, - content, - content_vector, - ], - ) + await self.db.execute(merge_sql,data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: @@ -429,12 +427,14 @@ class OracleGraphStorage(BaseGraphStorage): #################### query method ################# async def has_node(self, node_id: str) -> bool: """根据节点id检查节点是否存在""" - SQL = SQL_TEMPLATES["has_node"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["has_node"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(SQL) # print(self.db.workspace, node_id) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Node exist!",res) return True @@ -444,13 +444,14 @@ class OracleGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """根据源和目标节点id检查边是否存在""" - SQL = SQL_TEMPLATES["has_edge"].format( - workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id, - ) + SQL = SQL_TEMPLATES["has_edge"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id, + "target_node_id":target_node_id + } # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Edge exist!",res) return True @@ -460,11 +461,13 @@ class OracleGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """根据节点id获取节点的度""" - SQL = SQL_TEMPLATES["node_degree"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["node_degree"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Node degree",res["degree"]) return res["degree"] @@ -480,12 +483,14 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: """根据节点id获取节点数据""" - SQL = SQL_TEMPLATES["get_node"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["get_node"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(self.db.workspace, node_id) # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Get node!",self.db.workspace, node_id,res) return res @@ -497,12 +502,13 @@ class OracleGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: """根据源和目标节点id获取边""" - SQL = SQL_TEMPLATES["get_edge"].format( - workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id, - ) - res = await self.db.query(SQL) + SQL = SQL_TEMPLATES["get_edge"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id, + "target_node_id":target_node_id + } + res = await self.db.query(SQL,params) if res: # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) return res @@ -513,10 +519,12 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node_edges(self, source_node_id: str): """根据节点id获取节点的所有边""" if await self.has_node(source_node_id): - SQL = SQL_TEMPLATES["get_node_edges"].format( - workspace=self.db.workspace, source_node_id=source_node_id - ) - res = await self.db.query(sql=SQL, multirows=True) + SQL = SQL_TEMPLATES["get_node_edges"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id + } + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: data = [(i["source_name"], i["target_name"]) for i in res] # print("Get node edge!",self.db.workspace, source_node_id,data) @@ -524,8 +532,22 @@ class OracleGraphStorage(BaseGraphStorage): else: # print("Node Edge not exist!",self.db.workspace, source_node_id) return [] + + async def get_all_nodes(self, limit: int): + """查询所有节点""" + SQL = SQL_TEMPLATES["get_all_nodes"] + params = {"workspace":self.db.workspace, "limit":str(limit)} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res - + async def get_all_edges(self, limit: int): + """查询所有边""" + SQL = SQL_TEMPLATES["get_all_edges"] + params = {"workspace":self.db.workspace, "limit":str(limit)} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res N_T = { "full_docs": "LIGHTRAG_DOC_FULL", "text_chunks": "LIGHTRAG_DOC_CHUNKS", @@ -619,82 +641,96 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", - "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", - "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", - "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})", + "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL - ON (a.id = '{check_id}') + ON (a.id = :check_id) WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:1,:2,:3) + INSERT(id,content,workspace) values(:id,:content,:workspace) """, "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a USING DUAL - ON (a.id = '{check_id}') + ON (a.id = :check_id) WHEN NOT MATCHED THEN INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:1,:2,:3,:4,:5,:6,:7) """, + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, # SQL for VectorStorage "entities": """SELECT name as entity_name FROM - (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') - WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace) + WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM - (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') - WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace) + WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", "chunks": """SELECT id FROM - (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') - WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace) + WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", # SQL for GraphStorage "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) - WHERE a.workspace='{workspace}' AND a.name='{node_id}' + WHERE a.workspace=:workspace AND a.name=:node_id COLUMNS (a.name))""", "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) -[e]-> (b) - WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{source_node_id}' AND b.name='{target_node_id}' + WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:source_node_id AND b.name=:target_node_id COLUMNS (e.source_name,e.target_name) )""", "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) - WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{node_id}' or b.name = '{node_id}' + WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:node_id or b.name = :node_id COLUMNS (a.name))""", "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description FROM GRAPH_TABLE (lightrag_graph MATCH (a) - WHERE a.workspace='{workspace}' AND a.name='{node_id}' + WHERE a.workspace=:workspace AND a.name=:node_id COLUMNS (a.name) ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name - WHERE t2.workspace='{workspace}'""", + WHERE t2.workspace=:workspace""", "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) - WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{source_node_id}' and b.name = '{target_node_id}' + WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:source_node_id and b.name = :target_node_id COLUMNS (e.id,a.name as source_id) ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", "get_node_edges": """SELECT source_name,target_name FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) - WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{source_node_id}' + WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:source_node_id COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}') + ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}') + ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content + FROM LIGHTRAG_GRAPH_NODES t1 + LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id + WHERE t1.workspace=:workspace + order by t1.CREATETIME DESC + fetch first :limit rows only + """, + "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, + t1.weight,t1.DESCRIPTION,t2.content + FROM LIGHTRAG_GRAPH_EDGES t1 + LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id + WHERE t1.workspace=:workspace + order by t1.CREATETIME DESC + fetch first :limit rows only""" } From 89d7967349fe30c1bf682935ce95d54082b7d9fe Mon Sep 17 00:00:00 2001 From: tmuife <43266626@qq.com> Date: Mon, 18 Nov 2024 13:52:49 +0800 Subject: [PATCH 2/2] use pre-commit reformat --- lightrag/kg/oracle_impl.py | 176 ++++++++++++++++++------------------- 1 file changed, 87 insertions(+), 89 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index fd8bf536..b46d36d8 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -114,7 +114,9 @@ class OracleDB: logger.info("Finished check all tables in Oracle database") - async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]: + async def query( + self, sql: str, params: dict = None, multirows: bool = False + ) -> Union[dict, None]: async with self.pool.acquire() as connection: connection.inputtypehandler = self.input_type_handler connection.outputtypehandler = self.output_type_handler @@ -174,9 +176,9 @@ class OracleKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> Union[dict, None]: """根据 id 获取 doc_full 数据.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"workspace":self.db.workspace, "id":id} + params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: data = res # {"data":res} # print (data) @@ -187,11 +189,13 @@ class OracleKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """根据 id 获取 doc_chunks 数据""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids])) - params = {"workspace":self.db.workspace} - #print("get_by_ids:"+SQL) - #print(params) - res = await self.db.query(SQL,params, multirows=True) + SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + # print("get_by_ids:"+SQL) + # print(params) + res = await self.db.query(SQL, params, multirows=True) if res: data = res # [{"data":i} for i in res] # print(data) @@ -201,16 +205,17 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" - SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], - ids=",".join([f"'{id}'" for id in keys])) - params = {"workspace":self.db.workspace} + SQL = SQL_TEMPLATES["filter_keys"].format( + table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) + ) + params = {"workspace": self.db.workspace} try: await self.db.query(SQL, params) except Exception as e: logger.error(f"Oracle database error: {e}") print(SQL) print(params) - res = await self.db.query(SQL, params,multirows=True) + res = await self.db.query(SQL, params, multirows=True) data = None if res: exist_keys = [key["id"] for key in res] @@ -248,15 +253,16 @@ class OracleKVStorage(BaseKVStorage): # print(list_data) for item in list_data: merge_sql = SQL_TEMPLATES["merge_chunk"] - data = {"check_id":item["__id__"], - "id":item["__id__"], - "content":item["content"], - "workspace":self.db.workspace, - "tokens":item["tokens"], - "chunk_order_index":item["chunk_order_index"], - "full_doc_id":item["full_doc_id"], - "content_vector":item["__vector__"] - } + data = { + "check_id": item["__id__"], + "id": item["__id__"], + "content": item["content"], + "workspace": self.db.workspace, + "tokens": item["tokens"], + "chunk_order_index": item["chunk_order_index"], + "full_doc_id": item["full_doc_id"], + "content_vector": item["__vector__"], + } # print(merge_sql) await self.db.execute(merge_sql, data) @@ -265,11 +271,11 @@ class OracleKVStorage(BaseKVStorage): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] data = { - "check_id":k, - "id":k, - "content":v["content"], - "workspace":self.db.workspace - } + "check_id": k, + "id": k, + "content": v["content"], + "workspace": self.db.workspace, + } # print(merge_sql) await self.db.execute(merge_sql, data) return left_data @@ -301,7 +307,7 @@ class OracleVectorDBStorage(BaseVectorStorage): # 转换精度 dtype = str(embedding.dtype).upper() dimension = embedding.shape[0] - embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]" + embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) params = { @@ -309,9 +315,9 @@ class OracleVectorDBStorage(BaseVectorStorage): "workspace": self.db.workspace, "top_k": top_k, "better_than_threshold": self.cosine_better_than_threshold, - } + } # print(SQL) - results = await self.db.query(SQL,params=params, multirows=True) + results = await self.db.query(SQL, params=params, multirows=True) # print("vector search result:",results) return results @@ -346,18 +352,18 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"] + merge_sql = SQL_TEMPLATES["merge_node"] data = { - "workspace":self.db.workspace, - "name":entity_name, - "entity_type":entity_type, - "description":description, - "source_chunk_id":source_id, - "content":content, - "content_vector":content_vector - } + "workspace": self.db.workspace, + "name": entity_name, + "entity_type": entity_type, + "description": description, + "source_chunk_id": source_id, + "content": content, + "content_vector": content_vector, + } # print(merge_sql) - await self.db.execute(merge_sql,data) + await self.db.execute(merge_sql, data) # self._graph.add_node(node_id, **node_data) async def upsert_edge( @@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage): keywords = edge_data["keywords"] description = edge_data["description"] source_chunk_id = edge_data["source_id"] - logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}") + logger.debug( + f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" + ) content = keywords + source_name + target_name + description contents = [content] @@ -384,20 +392,20 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"] + merge_sql = SQL_TEMPLATES["merge_edge"] data = { - "workspace":self.db.workspace, - "source_name":source_name, - "target_name":target_name, - "weight":weight, - "keywords":keywords, - "description":description, - "source_chunk_id":source_chunk_id, - "content":content, - "content_vector":content_vector - } + "workspace": self.db.workspace, + "source_name": source_name, + "target_name": target_name, + "weight": weight, + "keywords": keywords, + "description": description, + "source_chunk_id": source_chunk_id, + "content": content, + "content_vector": content_vector, + } # print(merge_sql) - await self.db.execute(merge_sql,data) + await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: @@ -428,13 +436,10 @@ class OracleGraphStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: """根据节点id检查节点是否存在""" SQL = SQL_TEMPLATES["has_node"] - params = { - "workspace":self.db.workspace, - "node_id":node_id - } + params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) # print(self.db.workspace, node_id) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Node exist!",res) return True @@ -446,12 +451,12 @@ class OracleGraphStorage(BaseGraphStorage): """根据源和目标节点id检查边是否存在""" SQL = SQL_TEMPLATES["has_edge"] params = { - "workspace":self.db.workspace, - "source_node_id":source_node_id, - "target_node_id":target_node_id - } + "workspace": self.db.workspace, + "source_node_id": source_node_id, + "target_node_id": target_node_id, + } # print(SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Edge exist!",res) return True @@ -462,12 +467,9 @@ class OracleGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """根据节点id获取节点的度""" SQL = SQL_TEMPLATES["node_degree"] - params = { - "workspace":self.db.workspace, - "node_id":node_id - } + params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Node degree",res["degree"]) return res["degree"] @@ -484,13 +486,10 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: """根据节点id获取节点数据""" SQL = SQL_TEMPLATES["get_node"] - params = { - "workspace":self.db.workspace, - "node_id":node_id - } + params = {"workspace": self.db.workspace, "node_id": node_id} # print(self.db.workspace, node_id) # print(SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Get node!",self.db.workspace, node_id,res) return res @@ -504,11 +503,11 @@ class OracleGraphStorage(BaseGraphStorage): """根据源和目标节点id获取边""" SQL = SQL_TEMPLATES["get_edge"] params = { - "workspace":self.db.workspace, - "source_node_id":source_node_id, - "target_node_id":target_node_id - } - res = await self.db.query(SQL,params) + "workspace": self.db.workspace, + "source_node_id": source_node_id, + "target_node_id": target_node_id, + } + res = await self.db.query(SQL, params) if res: # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) return res @@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage): """根据节点id获取节点的所有边""" if await self.has_node(source_node_id): SQL = SQL_TEMPLATES["get_node_edges"] - params = { - "workspace":self.db.workspace, - "source_node_id":source_node_id - } + params = {"workspace": self.db.workspace, "source_node_id": source_node_id} res = await self.db.query(sql=SQL, params=params, multirows=True) if res: data = [(i["source_name"], i["target_name"]) for i in res] @@ -532,22 +528,24 @@ class OracleGraphStorage(BaseGraphStorage): else: # print("Node Edge not exist!",self.db.workspace, source_node_id) return [] - + async def get_all_nodes(self, limit: int): """查询所有节点""" SQL = SQL_TEMPLATES["get_all_nodes"] - params = {"workspace":self.db.workspace, "limit":str(limit)} - res = await self.db.query(sql=SQL,params=params, multirows=True) + params = {"workspace": self.db.workspace, "limit": str(limit)} + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: return res async def get_all_edges(self, limit: int): """查询所有边""" SQL = SQL_TEMPLATES["get_all_edges"] - params = {"workspace":self.db.workspace, "limit":str(limit)} - res = await self.db.query(sql=SQL,params=params, multirows=True) + params = {"workspace": self.db.workspace, "limit": str(limit)} + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: return res + + N_T = { "full_docs": "LIGHTRAG_DOC_FULL", "text_chunks": "LIGHTRAG_DOC_CHUNKS", @@ -719,18 +717,18 @@ SQL_TEMPLATES = { WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, - "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content + "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content FROM LIGHTRAG_GRAPH_NODES t1 LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id WHERE t1.workspace=:workspace order by t1.CREATETIME DESC fetch first :limit rows only """, - "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, + "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, t1.weight,t1.DESCRIPTION,t2.content FROM LIGHTRAG_GRAPH_EDGES t1 LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id WHERE t1.workspace=:workspace order by t1.CREATETIME DESC - fetch first :limit rows only""" + fetch first :limit rows only""", }