From d6589684ef4427be6544ed4d7c1274a68147d9bc Mon Sep 17 00:00:00 2001
From: jin <52519003+jin38324@users.noreply.github.com>
Date: Fri, 15 Nov 2024 12:57:01 +0800
Subject: [PATCH 1/8] use oracle bind variables to avoid error

---
 lightrag/base.py           |   1 +
 lightrag/kg/oracle_impl.py | 316 ++++++++++++++++++++-----------------
 lightrag/operate.py        |  16 +-
 lightrag/utils.py          |   6 +-
 4 files changed, 193 insertions(+), 146 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index 46dfc800..ca46057f 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -17,6 +17,7 @@ T = TypeVar("T")
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
+    only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 96a9e795..e81c32d0 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -114,16 +114,17 @@ class OracleDB:

         logger.info("Finished check all tables in Oracle database")

-    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
             with connection.cursor() as cursor:
                 try:
-                    await cursor.execute(sql)
+                    await cursor.execute(sql, params)
                 except Exception as e:
                     logger.error(f"Oracle database error: {e}")
                     print(sql)
+                    print(params)
                     raise
                 columns = [column[0].lower() for column in cursor.description]
                 if multirows:
@@ -140,7 +141,7 @@ class OracleDB:
                     data = None
             return data

-    async def execute(self, sql: str, data: list = None):
+    async def execute(self, sql: str, data: list | dict = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
@@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):

     async def get_by_id(self, id: str) -> Union[dict, None]:
         """根据 id 获取 doc_full 数据."""
-        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
-            workspace=self.db.workspace, id=id
-        )
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
+        params = {"workspace":self.db.workspace, "id":id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL,params)
         if res:
             data = res  # {"data":res}
             # print (data)
@@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """根据 id 获取 doc_chunks 数据"""
-        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
-            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
-        )
-        # print("get_by_ids:"+SQL)
-        res = await self.db.query(SQL, multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
+        params = {"workspace":self.db.workspace}
+        #print("get_by_ids:"+SQL)
+        #print(params)
+        res = await self.db.query(SQL,params, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):

     async def filter_keys(self, keys: list[str]) -> set[str]:
         """过滤掉重复内容"""
-        SQL = SQL_TEMPLATES["filter_keys"].format(
-            table_name=N_T[self.namespace],
-            workspace=self.db.workspace,
-            ids=",".join([f"'{k}'" for
k in keys]), - ) - res = await self.db.query(SQL, multirows=True) + SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], + ids=",".join([f"'{id}'" for id in keys])) + params = {"workspace":self.db.workspace} + try: + await self.db.query(SQL, params) + except Exception as e: + logger.error(f"Oracle database error: {e}") + print(SQL) + print(params) + res = await self.db.query(SQL, params,multirows=True) data = None if res: exist_keys = [key["id"] for key in res] @@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage): d["__vector__"] = embeddings[i] # print(list_data) for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) - - values = [ - item["__id__"], - item["content"], - self.db.workspace, - item["tokens"], - item["chunk_order_index"], - item["full_doc_id"], - item["__vector__"], - ] + merge_sql = SQL_TEMPLATES["merge_chunk"] + data = {"check_id":item["__id__"], + "id":item["__id__"], + "content":item["content"], + "workspace":self.db.workspace, + "tokens":item["tokens"], + "chunk_order_index":item["chunk_order_index"], + "full_doc_id":item["full_doc_id"], + "content_vector":item["__vector__"] + } # print(merge_sql) - await self.db.execute(merge_sql, values) + await self.db.execute(merge_sql, data) if self.namespace == "full_docs": for k, v in self._data.items(): # values.clear() - merge_sql = SQL_TEMPLATES["merge_doc_full"].format( - check_id=k, - ) - values = [k, self._data[k]["content"], self.db.workspace] + merge_sql = SQL_TEMPLATES["merge_doc_full"] + data = { + "check_id":k, + "id":k, + "content":v["content"], + "workspace":self.db.workspace + } # print(merge_sql) - await self.db.execute(merge_sql, values) + await self.db.execute(merge_sql, data) return left_data async def index_done_callback(self): @@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage): # 转换精度 dtype = str(embedding.dtype).upper() dimension = embedding.shape[0] - embedding_string = ", ".join(map(str, embedding.tolist())) + embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]" - SQL = SQL_TEMPLATES[self.namespace].format( - embedding_string=embedding_string, - dimension=dimension, - dtype=dtype, - workspace=self.db.workspace, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) + SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) + params = { + "embedding_string": embedding_string, + "workspace": self.db.workspace, + "top_k": top_k, + "better_than_threshold": self.cosine_better_than_threshold, + } # print(SQL) - results = await self.db.query(SQL, multirows=True) + results = await self.db.query(SQL,params=params, multirows=True) # print("vector search result:",results) return results @@ -339,22 +344,18 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"].format( - workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id - ) + merge_sql = SQL_TEMPLATES["merge_node"] + data = { + "workspace":self.db.workspace, + "name":entity_name, + "entity_type":entity_type, + "description":description, + "source_chunk_id":source_id, + "content":content, + "content_vector":content_vector + } # print(merge_sql) - await self.db.execute( - merge_sql, - [ - self.db.workspace, - entity_name, - entity_type, - description, - source_id, - content, - content_vector, - ], - ) + await self.db.execute(merge_sql,data) # self._graph.add_node(node_id, **node_data) async def 
upsert_edge( @@ -379,27 +380,20 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"].format( - workspace=self.db.workspace, - source_name=source_name, - target_name=target_name, - source_chunk_id=source_chunk_id, - ) + merge_sql = SQL_TEMPLATES["merge_edge"] + data = { + "workspace":self.db.workspace, + "source_name":source_name, + "target_name":target_name, + "weight":weight, + "keywords":keywords, + "description":description, + "source_chunk_id":source_chunk_id, + "content":content, + "content_vector":content_vector + } # print(merge_sql) - await self.db.execute( - merge_sql, - [ - self.db.workspace, - source_name, - target_name, - weight, - keywords, - description, - source_chunk_id, - content, - content_vector, - ], - ) + await self.db.execute(merge_sql,data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: @@ -429,12 +423,14 @@ class OracleGraphStorage(BaseGraphStorage): #################### query method ################# async def has_node(self, node_id: str) -> bool: """根据节点id检查节点是否存在""" - SQL = SQL_TEMPLATES["has_node"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["has_node"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(SQL) # print(self.db.workspace, node_id) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Node exist!",res) return True @@ -444,13 +440,14 @@ class OracleGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """根据源和目标节点id检查边是否存在""" - SQL = SQL_TEMPLATES["has_edge"].format( - workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id, - ) + SQL = SQL_TEMPLATES["has_edge"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id, + "target_node_id":target_node_id + } # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Edge exist!",res) return True @@ -460,11 +457,13 @@ class OracleGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """根据节点id获取节点的度""" - SQL = SQL_TEMPLATES["node_degree"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["node_degree"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Node degree",res["degree"]) return res["degree"] @@ -480,12 +479,14 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: """根据节点id获取节点数据""" - SQL = SQL_TEMPLATES["get_node"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["get_node"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(self.db.workspace, node_id) # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Get node!",self.db.workspace, node_id,res) return res @@ -497,12 +498,13 @@ class OracleGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: """根据源和目标节点id获取边""" - SQL = SQL_TEMPLATES["get_edge"].format( - workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id, - ) - res = await self.db.query(SQL) + SQL = 
SQL_TEMPLATES["get_edge"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id, + "target_node_id":target_node_id + } + res = await self.db.query(SQL,params) if res: # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) return res @@ -513,10 +515,12 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node_edges(self, source_node_id: str): """根据节点id获取节点的所有边""" if await self.has_node(source_node_id): - SQL = SQL_TEMPLATES["get_node_edges"].format( - workspace=self.db.workspace, source_node_id=source_node_id - ) - res = await self.db.query(sql=SQL, multirows=True) + SQL = SQL_TEMPLATES["get_node_edges"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id + } + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: data = [(i["source_name"], i["target_name"]) for i in res] # print("Get node edge!",self.db.workspace, source_node_id,data) @@ -524,8 +528,22 @@ class OracleGraphStorage(BaseGraphStorage): else: # print("Node Edge not exist!",self.db.workspace, source_node_id) return [] + + async def get_all_nodes(self, limit: int): + """查询所有节点""" + SQL = SQL_TEMPLATES["get_all_nodes"] + params = {"workspace":self.db.workspace, "limit":str(limit)} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res - + async def get_all_edges(self, limit: int): + """查询所有边""" + SQL = SQL_TEMPLATES["get_all_edges"] + params = {"workspace":self.db.workspace, "limit":str(limit)} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res N_T = { "full_docs": "LIGHTRAG_DOC_FULL", "text_chunks": "LIGHTRAG_DOC_CHUNKS", @@ -619,82 +637,96 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", - "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", - "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", - "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})", + "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL - ON (a.id = '{check_id}') + ON (a.id = :check_id) WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:1,:2,:3) + INSERT(id,content,workspace) values(:id,:content,:workspace) """, "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a USING DUAL - ON (a.id = '{check_id}') + ON (a.id = :check_id) WHEN NOT MATCHED 
THEN INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:1,:2,:3,:4,:5,:6,:7) """, + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, # SQL for VectorStorage "entities": """SELECT name as entity_name FROM - (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') - WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace) + WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM - (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') - WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace) + WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", "chunks": """SELECT id FROM - (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') - WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", + (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace) + WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", # SQL for GraphStorage "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) - WHERE a.workspace='{workspace}' AND a.name='{node_id}' + WHERE a.workspace=:workspace AND a.name=:node_id COLUMNS (a.name))""", "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph MATCH (a) -[e]-> (b) - WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{source_node_id}' AND b.name='{target_node_id}' + WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:source_node_id AND b.name=:target_node_id COLUMNS (e.source_name,e.target_name) )""", "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) - WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{node_id}' or b.name = '{node_id}' + WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:node_id or b.name = :node_id COLUMNS (a.name))""", "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description FROM GRAPH_TABLE (lightrag_graph MATCH (a) - WHERE a.workspace='{workspace}' AND a.name='{node_id}' + WHERE a.workspace=:workspace AND a.name=:node_id COLUMNS (a.name) ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name - WHERE t2.workspace='{workspace}'""", + WHERE t2.workspace=:workspace""", "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as 
source_id,t2.keywords, NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) - WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{source_node_id}' and b.name = '{target_node_id}' + WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:source_node_id and b.name = :target_node_id COLUMNS (e.id,a.name as source_id) ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", "get_node_edges": """SELECT source_name,target_name FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) - WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' - AND a.name='{source_node_id}' + WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace + AND a.name=:source_node_id COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}') + ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}') + ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content + FROM LIGHTRAG_GRAPH_NODES t1 + LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id + WHERE t1.workspace=:workspace + order by t1.CREATETIME DESC + fetch first :limit rows only + """, + "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, + t1.weight,t1.DESCRIPTION,t2.content + FROM LIGHTRAG_GRAPH_EDGES t1 + LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id + WHERE t1.workspace=:workspace + order by t1.CREATETIME DESC + fetch first :limit rows only""" } diff --git a/lightrag/operate.py b/lightrag/operate.py index db7c9401..285b6e35 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -405,12 +405,13 @@ async def local_query( kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) json_text = locate_json_string_body_from_string(result) - + logger.debug("local_query json_text:", json_text) try: keywords_data = json.loads(json_text) keywords = keywords_data.get("low_level_keywords", []) keywords = ", ".join(keywords) except json.JSONDecodeError: + print(result) try: result = ( result.replace(kw_prompt[:-1], "") @@ -443,6 +444,8 @@ async def local_query( sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) + if query_param.only_need_prompt: + return sys_prompt response = await use_model_func( query, 
system_prompt=sys_prompt, @@ -672,12 +675,12 @@ async def global_query( kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) json_text = locate_json_string_body_from_string(result) - + logger.debug("global json_text:", json_text) try: keywords_data = json.loads(json_text) keywords = keywords_data.get("high_level_keywords", []) keywords = ", ".join(keywords) - except json.JSONDecodeError: + except json.JSONDecodeError: try: result = ( result.replace(kw_prompt[:-1], "") @@ -714,6 +717,8 @@ async def global_query( sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) + if query_param.only_need_prompt: + return sys_prompt response = await use_model_func( query, system_prompt=sys_prompt, @@ -914,6 +919,7 @@ async def hybrid_query( result = await use_model_func(kw_prompt) json_text = locate_json_string_body_from_string(result) + logger.debug("hybrid_query json_text:", json_text) try: keywords_data = json.loads(json_text) hl_keywords = keywords_data.get("high_level_keywords", []) @@ -969,6 +975,8 @@ async def hybrid_query( sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type ) + if query_param.only_need_prompt: + return sys_prompt response = await use_model_func( query, system_prompt=sys_prompt, @@ -1079,6 +1087,8 @@ async def naive_query( sys_prompt = sys_prompt_temp.format( content_data=section, response_type=query_param.response_type ) + if query_param.only_need_prompt: + return sys_prompt response = await use_model_func( query, system_prompt=sys_prompt, diff --git a/lightrag/utils.py b/lightrag/utils.py index 104c9fec..473c5d39 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -49,7 +49,11 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: """Locate the JSON string body from a string""" maybe_json_str = re.search(r"{.*}", content, re.DOTALL) if maybe_json_str is not None: - return maybe_json_str.group(0) + maybe_json_str = maybe_json_str.group(0) + maybe_json_str = maybe_json_str.replace("\\n", "") + maybe_json_str = maybe_json_str.replace("\n", "") + maybe_json_str = maybe_json_str.replace("'", '"') + return maybe_json_str else: return None From af3aef5d88d52c7f5fde264f6fa3685c842f9ad5 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:29:55 +0800 Subject: [PATCH 2/8] Optimization logic --- .gitignore | 2 + examples/lightrag_api_oracle_demo..py | 133 ++++---- examples/lightrag_oracle_demo.py | 2 + lightrag/base.py | 2 + lightrag/kg/oracle_impl.py | 46 ++- lightrag/lightrag.py | 29 +- lightrag/llm.py | 11 +- lightrag/operate.py | 432 +++++++++----------------- lightrag/prompt.py | 80 +++-- lightrag/utils.py | 28 +- 10 files changed, 342 insertions(+), 423 deletions(-) diff --git a/.gitignore b/.gitignore index 942c2c25..01e145a8 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ ignore_this.txt .venv/ *.ignore.* .ruff_cache/ +gui/ +*.log \ No newline at end of file diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py index 3bfae452..3b2cafc6 100644 --- a/examples/lightrag_api_oracle_demo..py +++ b/examples/lightrag_api_oracle_demo..py @@ -1,11 +1,16 @@ + from fastapi import FastAPI, HTTPException, File, UploadFile +from fastapi import Query from contextlib import asynccontextmanager from pydantic import BaseModel -from typing import Optional +from typing import Optional,Any +from fastapi.responses import JSONResponse -import sys 
-import os +import sys, os +print(os.getcwd()) from pathlib import Path +script_directory = Path(__file__).resolve().parent.parent +sys.path.append(os.path.abspath(script_directory)) import asyncio import nest_asyncio @@ -13,15 +18,11 @@ from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.utils import EmbeddingFunc import numpy as np +from datetime import datetime from lightrag.kg.oracle_impl import OracleDB -print(os.getcwd()) - -script_directory = Path(__file__).resolve().parent.parent -sys.path.append(os.path.abspath(script_directory)) - # Apply nest_asyncio to solve event loop issues nest_asyncio.apply() @@ -37,18 +38,16 @@ APIKEY = "ocigenerativeai" # Configure working directory WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") print(f"WORKING_DIR: {WORKING_DIR}") -LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus") +LLM_MODEL = os.environ.get("LLM_MODEL", "cohere.command-r-plus-08-2024") print(f"LLM_MODEL: {LLM_MODEL}") EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "cohere.embed-multilingual-v3.0") print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 512)) print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") - if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) - - + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -78,10 +77,10 @@ async def get_embedding_dim(): embedding_dim = embedding.shape[1] return embedding_dim - async def init(): + # Detect embedding dimension - embedding_dimension = await get_embedding_dim() + embedding_dimension = 1024 #await get_embedding_dim() print(f"Detected embedding dimension: {embedding_dimension}") # Create Oracle DB connection # The `config` parameter is the connection configuration of Oracle DB @@ -89,36 +88,36 @@ async def init(): # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud - oracle_db = OracleDB( - config={ - "user": "", - "password": "", - "dsn": "", - "config_dir": "", - "wallet_location": "", - "wallet_password": "", - "workspace": "", - } # specify which docs you want to store and query - ) + oracle_db = OracleDB(config={ + "user":"", + "password":"", + "dsn":"", + "config_dir":"path_to_config_dir", + "wallet_location":"path_to_wallet_location", + "wallet_password":"wallet_password", + "workspace":"company" + } # specify which docs you want to store and query + ) + # Check if Oracle DB tables exist, if not, tables will be created await oracle_db.check_tables() # Initialize LightRAG - # We use Oracle DB as the KV/vector/graph storage + # We use Oracle DB as the KV/vector/graph storage rag = LightRAG( - enable_llm_cache=False, - working_dir=WORKING_DIR, - chunk_token_size=512, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=512, - func=embedding_func, - ), - graph_storage="OracleGraphStorage", - kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage", - ) + enable_llm_cache=False, + working_dir=WORKING_DIR, + chunk_token_size=512, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=512, + func=embedding_func, + ), + graph_storage = "OracleGraphStorage", + kv_storage="OracleKVStorage", + 
vector_storage="OracleVectorDBStorage" + ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool rag.graph_storage_cls.db = oracle_db @@ -128,6 +127,17 @@ async def init(): return rag +# Extract and Insert into LightRAG storage +#with open("./dickens/book.txt", "r", encoding="utf-8") as f: +# await rag.ainsert(f.read()) + +# # Perform search in different modes +# modes = ["naive", "local", "global", "hybrid"] +# for mode in modes: +# print("="*20, mode, "="*20) +# print(await rag.aquery("这篇文档是关于什么内容的?", param=QueryParam(mode=mode))) +# print("-"*100, "\n") + # Data models @@ -135,7 +145,10 @@ class QueryRequest(BaseModel): query: str mode: str = "hybrid" only_need_context: bool = False + only_need_prompt: bool = False +class DataRequest(BaseModel): + limit: int = 100 class InsertRequest(BaseModel): text: str @@ -143,7 +156,7 @@ class InsertRequest(BaseModel): class Response(BaseModel): status: str - data: Optional[str] = None + data: Optional[Any] = None message: Optional[str] = None @@ -151,7 +164,6 @@ class Response(BaseModel): rag = None # 定义为全局对象 - @asynccontextmanager async def lifespan(app: FastAPI): global rag @@ -160,24 +172,39 @@ async def lifespan(app: FastAPI): yield -app = FastAPI( - title="LightRAG API", description="API for RAG operations", lifespan=lifespan -) - +app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan) @app.post("/query", response_model=Response) async def query_endpoint(request: QueryRequest): - try: + #try: # loop = asyncio.get_event_loop() - result = await rag.aquery( + if request.mode == "naive": + top_k = 3 + else: + top_k = 60 + result = await rag.aquery( request.query, param=QueryParam( - mode=request.mode, only_need_context=request.only_need_context + mode=request.mode, + only_need_context=request.only_need_context, + only_need_prompt=request.only_need_prompt, + top_k=top_k ), ) - return Response(status="success", data=result) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + return Response(status="success", data=result) + # except Exception as e: + # raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/data", response_model=Response) +async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)): + if type == "nodes": + result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit) + elif type == "edges": + result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit) + elif type == "statistics": + result = await rag.chunk_entity_relation_graph.get_statistics() + return Response(status="success", data=result) @app.post("/insert", response_model=Response) @@ -220,7 +247,7 @@ async def health_check(): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8020) + uvicorn.run(app, host="127.0.0.1", port=8020) # Usage example # To run the server, use the following command in your terminal: @@ -237,4 +264,4 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' # 4. 
Health check: -# curl -X GET "http://127.0.0.1:8020/health" +# curl -X GET "http://127.0.0.1:8020/health" \ No newline at end of file diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 365b6225..b915c76b 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -97,6 +97,8 @@ async def main(): graph_storage="OracleGraphStorage", kv_storage="OracleKVStorage", vector_storage="OracleVectorDBStorage", + + addon_params = {"example_number":1, "language":"Simplfied Chinese"}, ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool diff --git a/lightrag/base.py b/lightrag/base.py index ca46057f..ea84c000 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -21,6 +21,8 @@ class QueryParam: response_type: str = "Multiple Paragraphs" # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. top_k: int = 60 + # Number of document chunks to retrieve. + # top_n: int = 10 # Number of tokens for the original chunks. max_token_for_text_unit: int = 4000 # Number of tokens for the relationship descriptions diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index e81c32d0..2e394b8a 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -333,6 +333,8 @@ class OracleGraphStorage(BaseGraphStorage): entity_type = node_data["entity_type"] description = node_data["description"] source_id = node_data["source_id"] + logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") + content = entity_name + description contents = [content] batches = [ @@ -369,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage): keywords = edge_data["keywords"] description = edge_data["description"] source_chunk_id = edge_data["source_id"] + logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}") + content = keywords + source_name + target_name + description contents = [content] batches = [ @@ -544,6 +548,14 @@ class OracleGraphStorage(BaseGraphStorage): res = await self.db.query(sql=SQL,params=params, multirows=True) if res: return res + + async def get_statistics(self): + SQL = SQL_TEMPLATES["get_statistics"] + params = {"workspace":self.db.workspace} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res + N_T = { "full_docs": "LIGHTRAG_DOC_FULL", "text_chunks": "LIGHTRAG_DOC_CHUNKS", @@ -715,18 +727,36 @@ SQL_TEMPLATES = { WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, - "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content - FROM LIGHTRAG_GRAPH_NODES t1 - LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id - WHERE t1.workspace=:workspace - order by t1.CREATETIME DESC - fetch first :limit rows only - """, + "get_all_nodes":"""WITH t0 AS ( + SELECT name AS id, entity_type AS label, entity_type, description, + '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids + FROM lightrag_graph_nodes + WHERE workspace = :workspace + ORDER BY createtime DESC fetch first :limit rows only + ), t1 AS ( + SELECT t0.id, source_chunk_id + FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) ) + ), t2 AS ( + SELECT t1.id, LISTAGG(t2.content, '\n') content + FROM t1 LEFT JOIN 
lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id + GROUP BY t1.id + ) + SELECT t0.id, label, entity_type, description, t2.content + FROM t0 LEFT JOIN t2 ON t0.id = t2.id""", "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, t1.weight,t1.DESCRIPTION,t2.content FROM LIGHTRAG_GRAPH_EDGES t1 LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id WHERE t1.workspace=:workspace order by t1.CREATETIME DESC - fetch first :limit rows only""" + fetch first :limit rows only""", + "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count, + count(distinct CASE WHEN type='edge' THEN id END) as edges_count + FROM ( + select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph + MATCH (a) WHERE a.workspace=:workspace columns(a.name as id)) + UNION + select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph + MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) + )""", } diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 50e33405..2687877a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -12,9 +12,8 @@ from .llm import ( from .operate import ( chunking_by_token_size, extract_entities, - local_query, - global_query, - hybrid_query, + # local_query,global_query,hybrid_query, + kg_query, naive_query, ) @@ -309,28 +308,8 @@ class LightRAG: return loop.run_until_complete(self.aquery(query, param)) async def aquery(self, query: str, param: QueryParam = QueryParam()): - if param.mode == "local": - response = await local_query( - query, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.text_chunks, - param, - asdict(self), - ) - elif param.mode == "global": - response = await global_query( - query, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.text_chunks, - param, - asdict(self), - ) - elif param.mode == "hybrid": - response = await hybrid_query( + if param.mode in ["local", "global", "hybrid"]: + response = await kg_query( query, self.chunk_entity_relation_graph, self.entities_vdb, diff --git a/lightrag/llm.py b/lightrag/llm.py index f4045e80..6263f153 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -69,12 +69,15 @@ async def openai_complete_if_cache( response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) - + content = response.choices[0].message.content + if r'\u' in content: + content = content.encode('utf-8').decode('unicode_escape') + print(content) if hashing_kv is not None: await hashing_kv.upsert( {args_hash: {"return": response.choices[0].message.content, "model": model}} ) - return response.choices[0].message.content + return content @retry( @@ -539,7 +542,7 @@ async def openai_embedding( texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, - api_key: str = None, + api_key: str = None ) -> np.ndarray: if api_key: os.environ["OPENAI_API_KEY"] = api_key @@ -548,7 +551,7 @@ async def openai_embedding( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) response = await openai_async_client.embeddings.create( - model=model, input=texts, encoding_format="float" + model=model, input=texts, encoding_format="float" ) return np.array([dp.embedding for dp in response.data]) diff --git a/lightrag/operate.py b/lightrag/operate.py index 285b6e35..12f78dcd 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -248,14 +248,23 @@ async def extract_entities( 
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] ordered_chunks = list(chunks.items()) - + # add language and example number params to prompt + language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"]) + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number str: context = None + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): + examples = "\n".join(PROMPTS["keywords_extraction_examples"][:int(example_number)]) + else: + examples="\n".join(PROMPTS["keywords_extraction_examples"]) + + # Set mode + if query_param.mode not in ["local", "global", "hybrid"]: + logger.error(f"Unknown mode {query_param.mode} in kg_query") + return PROMPTS["fail_response"] + + # LLM generate keywords use_model_func = global_config["llm_model_func"] - kw_prompt_temp = PROMPTS["keywords_extraction"] - kw_prompt = kw_prompt_temp.format(query=query) - result = await use_model_func(kw_prompt) - json_text = locate_json_string_body_from_string(result) - logger.debug("local_query json_text:", json_text) + kw_prompt = kw_prompt_temp.format(query=query,examples=examples) + result = await use_model_func(kw_prompt) + logger.info(f"kw_prompt result:") + print(result) try: + json_text = locate_json_string_body_from_string(result) keywords_data = json.loads(json_text) - keywords = keywords_data.get("low_level_keywords", []) - keywords = ", ".join(keywords) - except json.JSONDecodeError: - print(result) - try: - result = ( - result.replace(kw_prompt[:-1], "") - .replace("user", "") - .replace("model", "") - .strip() - ) - result = "{" + result.split("{")[1].split("}")[0] + "}" - - keywords_data = json.loads(result) - keywords = keywords_data.get("low_level_keywords", []) - keywords = ", ".join(keywords) - # Handle parsing error - except json.JSONDecodeError as e: - print(f"JSON parsing error: {e}") - return PROMPTS["fail_response"] - if keywords: - context = await _build_local_query_context( + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + + # Handle parsing error + except json.JSONDecodeError as e: + print(f"JSON parsing error: {e} {result}") + return PROMPTS["fail_response"] + + # Handdle keywords missing + if hl_keywords == [] and ll_keywords == []: + logger.warning("low_level_keywords and high_level_keywords is empty") + return PROMPTS["fail_response"] + if ll_keywords == [] and query_param.mode in ["local","hybrid"]: + logger.warning("low_level_keywords is empty") + return PROMPTS["fail_response"] + else: + ll_keywords = ", ".join(ll_keywords) + if hl_keywords == [] and query_param.mode in ["global","hybrid"]: + logger.warning("high_level_keywords is empty") + return PROMPTS["fail_response"] + else: + hl_keywords = ", ".join(hl_keywords) + + # Build context + keywords = [ll_keywords, hl_keywords] + context = await _build_query_context( keywords, knowledge_graph_inst, entities_vdb, + relationships_vdb, text_chunks_db, query_param, ) + if query_param.only_need_context: return context if context is None: @@ -443,13 +468,13 @@ async def local_query( sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type - ) + ) if query_param.only_need_prompt: return sys_prompt response = await use_model_func( query, system_prompt=sys_prompt, - ) + ) if len(response) 
> len(sys_prompt): response = ( response.replace(sys_prompt, "") @@ -464,22 +489,87 @@ async def local_query( return response -async def _build_local_query_context( +async def _build_query_context( + query: list, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + ): + ll_kewwords, hl_keywrds = query[0], query[1] + if query_param.mode in ["local", "hybrid"]: + if ll_kewwords == "": + ll_entities_context,ll_relations_context,ll_text_units_context = "","","" + warnings.warn("Low Level context is None. Return empty Low entity/relationship/source") + query_param.mode = "global" + else: + ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data( + ll_kewwords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param + ) + if query_param.mode in ["global", "hybrid"]: + if hl_keywrds == "": + hl_entities_context,hl_relations_context,hl_text_units_context = "","","" + warnings.warn("High Level context is None. Return empty High entity/relationship/source") + query_param.mode = "local" + else: + hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data( + hl_keywrds, + knowledge_graph_inst, + relationships_vdb, + text_chunks_db, + query_param + ) + if query_param.mode == 'hybrid': + entities_context,relations_context,text_units_context = combine_contexts( + [hl_entities_context,ll_entities_context], + [hl_relations_context,ll_relations_context], + [hl_text_units_context,ll_text_units_context] + ) + elif query_param.mode == 'local': + entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context + elif query_param.mode == 'global': + entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context + return f""" +# -----Entities----- +# ```csv +# {entities_context} +# ``` +# -----Relationships----- +# ```csv +# {relations_context} +# ``` +# -----Sources----- +# ```csv +# {text_units_context} +# ``` +# """ + + + +async def _get_node_data( query, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, ): + # 获取相似的实体 results = await entities_vdb.query(query, top_k=query_param.top_k) - if not len(results): return None + # 获取实体信息 node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] ) if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") + + # 获取实体的度 node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] ) @@ -488,15 +578,19 @@ async def _build_local_query_context( for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. 
+ # 根据实体获取文本片段 use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst ) + # 获取关联的边 use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst ) logger.info( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" - ) + ) + + # 构建提示词 entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(node_datas): entites_section_list.append( @@ -531,20 +625,7 @@ async def _build_local_query_context( for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) - return f""" ------Entities----- -```csv -{entities_context} -``` ------Relationships----- -```csv -{relations_context} -``` ------Sources----- -```csv -{text_units_context} -``` -""" + return entities_context,relations_context,text_units_context async def _find_most_related_text_unit_from_entities( @@ -659,88 +740,9 @@ async def _find_most_related_edges_from_entities( return all_edges_data -async def global_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, -) -> str: - context = None - use_model_func = global_config["llm_model_func"] - - kw_prompt_temp = PROMPTS["keywords_extraction"] - kw_prompt = kw_prompt_temp.format(query=query) - result = await use_model_func(kw_prompt) - json_text = locate_json_string_body_from_string(result) - logger.debug("global json_text:", json_text) - try: - keywords_data = json.loads(json_text) - keywords = keywords_data.get("high_level_keywords", []) - keywords = ", ".join(keywords) - except json.JSONDecodeError: - try: - result = ( - result.replace(kw_prompt[:-1], "") - .replace("user", "") - .replace("model", "") - .strip() - ) - result = "{" + result.split("{")[1].split("}")[0] + "}" - - keywords_data = json.loads(result) - keywords = keywords_data.get("high_level_keywords", []) - keywords = ", ".join(keywords) - - except json.JSONDecodeError as e: - # Handle parsing error - print(f"JSON parsing error: {e}") - return PROMPTS["fail_response"] - if keywords: - context = await _build_global_query_context( - keywords, - knowledge_graph_inst, - entities_vdb, - relationships_vdb, - text_chunks_db, - query_param, - ) - - if query_param.only_need_context: - return context - if context is None: - return PROMPTS["fail_response"] - - sys_prompt_temp = PROMPTS["rag_response"] - sys_prompt = sys_prompt_temp.format( - context_data=context, response_type=query_param.response_type - ) - if query_param.only_need_prompt: - return sys_prompt - response = await use_model_func( - query, - system_prompt=sys_prompt, - ) - if len(response) > len(sys_prompt): - response = ( - response.replace(sys_prompt, "") - .replace("user", "") - .replace("model", "") - .replace(query, "") - .replace("", "") - .replace("", "") - .strip() - ) - - return response - - -async def _build_global_query_context( +async def _get_edge_data( keywords, knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, @@ -782,6 +784,7 @@ async def _build_global_query_context( logger.info( f"Global query uses {len(use_entities)} entites, {len(edge_datas)} 
relations, {len(use_text_units)} text units" ) + relations_section_list = [ ["id", "source", "target", "description", "keywords", "weight", "rank"] ] @@ -816,21 +819,8 @@ async def _build_global_query_context( for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) + return entities_context,relations_context,text_units_context - return f""" ------Entities----- -```csv -{entities_context} -``` ------Relationships----- -```csv -{relations_context} -``` ------Sources----- -```csv -{text_units_context} -``` -""" async def _find_most_related_entities_from_relationships( @@ -901,137 +891,11 @@ async def _find_related_text_unit_from_relationships( return all_text_units -async def hybrid_query( - query, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], - query_param: QueryParam, - global_config: dict, -) -> str: - low_level_context = None - high_level_context = None - use_model_func = global_config["llm_model_func"] - - kw_prompt_temp = PROMPTS["keywords_extraction"] - kw_prompt = kw_prompt_temp.format(query=query) - - result = await use_model_func(kw_prompt) - json_text = locate_json_string_body_from_string(result) - logger.debug("hybrid_query json_text:", json_text) - try: - keywords_data = json.loads(json_text) - hl_keywords = keywords_data.get("high_level_keywords", []) - ll_keywords = keywords_data.get("low_level_keywords", []) - hl_keywords = ", ".join(hl_keywords) - ll_keywords = ", ".join(ll_keywords) - except json.JSONDecodeError: - try: - result = ( - result.replace(kw_prompt[:-1], "") - .replace("user", "") - .replace("model", "") - .strip() - ) - result = "{" + result.split("{")[1].split("}")[0] + "}" - keywords_data = json.loads(result) - hl_keywords = keywords_data.get("high_level_keywords", []) - ll_keywords = keywords_data.get("low_level_keywords", []) - hl_keywords = ", ".join(hl_keywords) - ll_keywords = ", ".join(ll_keywords) - # Handle parsing error - except json.JSONDecodeError as e: - print(f"JSON parsing error: {e}") - return PROMPTS["fail_response"] - - if ll_keywords: - low_level_context = await _build_local_query_context( - ll_keywords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, - ) - - if hl_keywords: - high_level_context = await _build_global_query_context( - hl_keywords, - knowledge_graph_inst, - entities_vdb, - relationships_vdb, - text_chunks_db, - query_param, - ) - - context = combine_contexts(high_level_context, low_level_context) - - if query_param.only_need_context: - return context - if context is None: - return PROMPTS["fail_response"] - - sys_prompt_temp = PROMPTS["rag_response"] - sys_prompt = sys_prompt_temp.format( - context_data=context, response_type=query_param.response_type - ) - if query_param.only_need_prompt: - return sys_prompt - response = await use_model_func( - query, - system_prompt=sys_prompt, - ) - if len(response) > len(sys_prompt): - response = ( - response.replace(sys_prompt, "") - .replace("user", "") - .replace("model", "") - .replace(query, "") - .replace("", "") - .replace("", "") - .strip() - ) - return response - - -def combine_contexts(high_level_context, low_level_context): +def combine_contexts(entities, relationships, sources): # Function to extract entities, relationships, and sources from context strings - - def extract_sections(context): - entities_match = re.search( - 
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL - ) - relationships_match = re.search( - r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL - ) - sources_match = re.search( - r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL - ) - - entities = entities_match.group(1) if entities_match else "" - relationships = relationships_match.group(1) if relationships_match else "" - sources = sources_match.group(1) if sources_match else "" - - return entities, relationships, sources - - # Extract sections from both contexts - - if high_level_context is None: - warnings.warn( - "High Level context is None. Return empty High entity/relationship/source" - ) - hl_entities, hl_relationships, hl_sources = "", "", "" - else: - hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context) - - if low_level_context is None: - warnings.warn( - "Low Level context is None. Return empty Low entity/relationship/source" - ) - ll_entities, ll_relationships, ll_sources = "", "", "" - else: - ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) - + hl_entities, ll_entities = entities[0], entities[1] + hl_relationships, ll_relationships = relationships[0],relationships[1] + hl_sources, ll_sources = sources[0], sources[1] # Combine and deduplicate the entities combined_entities = process_combine_contexts(hl_entities, ll_entities) @@ -1043,21 +907,7 @@ def combine_contexts(high_level_context, low_level_context): # Combine and deduplicate the sources combined_sources = process_combine_contexts(hl_sources, ll_sources) - # Format the combined context - return f""" ------Entities----- -```csv -{combined_entities} -``` ------Relationships----- -```csv -{combined_relationships} -``` ------Sources----- -```csv -{combined_sources} -``` -""" + return combined_entities, combined_relationships, combined_sources async def naive_query( @@ -1080,7 +930,7 @@ async def naive_query( max_token_size=query_param.max_token_for_text_unit, ) logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks") - section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) + section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks]) if query_param.only_need_context: return section sys_prompt_temp = PROMPTS["naive_rag_response"] diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 5de116b3..389f45f2 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -2,6 +2,7 @@ GRAPH_FIELD_SEP = "" PROMPTS = {} +PROMPTS["DEFAULT_LANGUAGE"] = "English" PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" @@ -11,6 +12,7 @@ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"] PROMPTS["entity_extraction"] = """-Goal- Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. +Use {language} as output language. -Steps- 1. Identify all entities. 
For each identified entity, extract the following information: @@ -38,7 +40,19 @@ Format the content-level key words as ("content_keywords"{tuple_delimiter} Union[str, None]: """Locate the JSON string body from a string""" - maybe_json_str = re.search(r"{.*}", content, re.DOTALL) - if maybe_json_str is not None: - maybe_json_str = maybe_json_str.group(0) - maybe_json_str = maybe_json_str.replace("\\n", "") - maybe_json_str = maybe_json_str.replace("\n", "") - maybe_json_str = maybe_json_str.replace("'", '"') - return maybe_json_str - else: + try: + maybe_json_str = re.search(r"{.*}", content, re.DOTALL) + if maybe_json_str is not None: + maybe_json_str = maybe_json_str.group(0) + maybe_json_str = maybe_json_str.replace("\\n", "") + maybe_json_str = maybe_json_str.replace("\n", "") + maybe_json_str = maybe_json_str.replace("'", '"') + json.loads(maybe_json_str) + return maybe_json_str + except: + # try: + # content = ( + # content.replace(kw_prompt[:-1], "") + # .replace("user", "") + # .replace("model", "") + # .strip() + # ) + # maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}" + # json.loads(maybe_json_str) + return None From 21f161390a92d404941c4adb3f623870f85b6fde Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:40:38 +0800 Subject: [PATCH 3/8] Logic Optimization --- .gitignore | 2 +- examples/lightrag_api_oracle_demo..py | 109 +++++++++--------- examples/lightrag_oracle_demo.py | 3 +- lightrag/kg/oracle_impl.py | 8 +- lightrag/llm.py | 8 +- lightrag/operate.py | 152 ++++++++++++++++---------- lightrag/prompt.py | 32 +++--- lightrag/utils.py | 7 +- 8 files changed, 185 insertions(+), 136 deletions(-) diff --git a/.gitignore b/.gitignore index 01e145a8..e6f5f5ba 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ ignore_this.txt *.ignore.* .ruff_cache/ gui/ -*.log \ No newline at end of file +*.log diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py index c06b8a83..8aaa2cf5 100644 --- a/examples/lightrag_api_oracle_demo..py +++ b/examples/lightrag_api_oracle_demo..py @@ -1,16 +1,14 @@ - from fastapi import FastAPI, HTTPException, File, UploadFile from fastapi import Query from contextlib import asynccontextmanager from pydantic import BaseModel -from typing import Optional,Any -from fastapi.responses import JSONResponse +from typing import Optional, Any + +import sys +import os + -import sys, os -print(os.getcwd()) from pathlib import Path -script_directory = Path(__file__).resolve().parent.parent -sys.path.append(os.path.abspath(script_directory)) import asyncio import nest_asyncio @@ -18,10 +16,12 @@ from lightrag import LightRAG, QueryParam from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.utils import EmbeddingFunc import numpy as np -from datetime import datetime from lightrag.kg.oracle_impl import OracleDB +print(os.getcwd()) +script_directory = Path(__file__).resolve().parent.parent +sys.path.append(os.path.abspath(script_directory)) # Apply nest_asyncio to solve event loop issues @@ -47,7 +47,8 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) - + + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -77,10 +78,10 @@ async def get_embedding_dim(): embedding_dim = embedding.shape[1] return embedding_dim + async def init(): - # Detect embedding dimension - embedding_dimension = 1024 #await get_embedding_dim() + 
embedding_dimension = 1024 # await get_embedding_dim() print(f"Detected embedding dimension: {embedding_dimension}") # Create Oracle DB connection # The `config` parameter is the connection configuration of Oracle DB @@ -88,36 +89,36 @@ async def init(): # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud + oracle_db = OracleDB( + config={ + "user": "", + "password": "", + "dsn": "", + "config_dir": "path_to_config_dir", + "wallet_location": "path_to_wallet_location", + "wallet_password": "wallet_password", + "workspace": "company", + } # specify which docs you want to store and query + ) - oracle_db = OracleDB(config={ - "user":"", - "password":"", - "dsn":"", - "config_dir":"path_to_config_dir", - "wallet_location":"path_to_wallet_location", - "wallet_password":"wallet_password", - "workspace":"company" - } # specify which docs you want to store and query - ) - # Check if Oracle DB tables exist, if not, tables will be created await oracle_db.check_tables() # Initialize LightRAG - # We use Oracle DB as the KV/vector/graph storage + # We use Oracle DB as the KV/vector/graph storage rag = LightRAG( - enable_llm_cache=False, - working_dir=WORKING_DIR, - chunk_token_size=512, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=embedding_dimension, - max_token_size=512, - func=embedding_func, - ), - graph_storage = "OracleGraphStorage", - kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage" - ) + enable_llm_cache=False, + working_dir=WORKING_DIR, + chunk_token_size=512, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=512, + func=embedding_func, + ), + graph_storage="OracleGraphStorage", + kv_storage="OracleKVStorage", + vector_storage="OracleVectorDBStorage", + ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool rag.graph_storage_cls.db = oracle_db @@ -128,7 +129,7 @@ async def init(): # Extract and Insert into LightRAG storage -#with open("./dickens/book.txt", "r", encoding="utf-8") as f: +# with open("./dickens/book.txt", "r", encoding="utf-8") as f: # await rag.ainsert(f.read()) # # Perform search in different modes @@ -147,9 +148,11 @@ class QueryRequest(BaseModel): only_need_context: bool = False only_need_prompt: bool = False + class DataRequest(BaseModel): limit: int = 100 + class InsertRequest(BaseModel): text: str @@ -164,6 +167,7 @@ class Response(BaseModel): rag = None + @asynccontextmanager async def lifespan(app: FastAPI): global rag @@ -172,25 +176,28 @@ async def lifespan(app: FastAPI): yield -app = FastAPI(title="LightRAG API", description="API for RAG operations",lifespan=lifespan) +app = FastAPI( + title="LightRAG API", description="API for RAG operations", lifespan=lifespan +) + @app.post("/query", response_model=Response) async def query_endpoint(request: QueryRequest): - #try: - # loop = asyncio.get_event_loop() + # try: + # loop = asyncio.get_event_loop() if request.mode == "naive": top_k = 3 else: top_k = 60 result = await rag.aquery( - request.query, - param=QueryParam( - mode=request.mode, - only_need_context=request.only_need_context, - only_need_prompt=request.only_need_prompt, - top_k=top_k - ), - ) + request.query, + param=QueryParam( + mode=request.mode, + only_need_context=request.only_need_context, + 
only_need_prompt=request.only_need_prompt, + top_k=top_k, + ), + ) return Response(status="success", data=result) # except Exception as e: # raise HTTPException(status_code=500, detail=str(e)) @@ -199,9 +206,9 @@ async def query_endpoint(request: QueryRequest): @app.get("/data", response_model=Response) async def query_all_nodes(type: str = Query("nodes"), limit: int = Query(100)): if type == "nodes": - result = await rag.chunk_entity_relation_graph.get_all_nodes(limit = limit) + result = await rag.chunk_entity_relation_graph.get_all_nodes(limit=limit) elif type == "edges": - result = await rag.chunk_entity_relation_graph.get_all_edges(limit = limit) + result = await rag.chunk_entity_relation_graph.get_all_edges(limit=limit) elif type == "statistics": result = await rag.chunk_entity_relation_graph.get_statistics() return Response(status="success", data=result) @@ -264,4 +271,4 @@ if __name__ == "__main__": # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' # 4. Health check: -# curl -X GET "http://127.0.0.1:8020/health" \ No newline at end of file +# curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index b915c76b..630c1fd8 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -97,8 +97,7 @@ async def main(): graph_storage="OracleGraphStorage", kv_storage="OracleKVStorage", vector_storage="OracleVectorDBStorage", - - addon_params = {"example_number":1, "language":"Simplfied Chinese"}, + addon_params={"example_number": 1, "language": "Simplfied Chinese"}, ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 2cfbd249..08ce79d5 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -114,7 +114,9 @@ class OracleDB: logger.info("Finished check all tables in Oracle database") - async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]: + async def query( + self, sql: str, params: dict = None, multirows: bool = False + ) -> Union[dict, None]: async with self.pool.acquire() as connection: connection.inputtypehandler = self.input_type_handler connection.outputtypehandler = self.output_type_handler @@ -256,7 +258,7 @@ class OracleKVStorage(BaseKVStorage): item["__vector__"], ] # print(merge_sql) - await self.db.execute(merge_sql, data) + await self.db.execute(merge_sql, values) if self.namespace == "full_docs": for k, v in self._data.items(): @@ -266,7 +268,7 @@ class OracleKVStorage(BaseKVStorage): ) values = [k, self._data[k]["content"], self.db.workspace] # print(merge_sql) - await self.db.execute(merge_sql, data) + await self.db.execute(merge_sql, values) return left_data async def index_done_callback(self): diff --git a/lightrag/llm.py b/lightrag/llm.py index 1acf07e0..d3729941 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -70,8 +70,8 @@ async def openai_complete_if_cache( model=model, messages=messages, **kwargs ) content = response.choices[0].message.content - if r'\u' in content: - content = content.encode('utf-8').decode('unicode_escape') + if r"\u" in content: + content = content.encode("utf-8").decode("unicode_escape") print(content) if hashing_kv is not None: await hashing_kv.upsert( @@ -542,7 +542,7 @@ async def openai_embedding( texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, - api_key: str = None + 
api_key: str = None, ) -> np.ndarray: if api_key: os.environ["OPENAI_API_KEY"] = api_key @@ -551,7 +551,7 @@ async def openai_embedding( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) response = await openai_async_client.embeddings.create( - model=model, input=texts, encoding_format="float" + model=model, input=texts, encoding_format="float" ) return np.array([dp.embedding for dp in response.data]) diff --git a/lightrag/operate.py b/lightrag/operate.py index 1071f8c2..c4740e70 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -249,13 +249,17 @@ async def extract_entities( ordered_chunks = list(chunks.items()) # add language and example number params to prompt - language = global_config["addon_params"].get("language",PROMPTS["DEFAULT_LANGUAGE"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) example_number = global_config["addon_params"].get("example_number", None) - if example_number and example_number len(sys_prompt): response = ( response.replace(sys_prompt, "") @@ -496,44 +504,72 @@ async def _build_query_context( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, - ): +): ll_kewwords, hl_keywrds = query[0], query[1] if query_param.mode in ["local", "hybrid"]: if ll_kewwords == "": - ll_entities_context,ll_relations_context,ll_text_units_context = "","","" - warnings.warn("Low Level context is None. Return empty Low entity/relationship/source") + ll_entities_context, ll_relations_context, ll_text_units_context = ( + "", + "", + "", + ) + warnings.warn( + "Low Level context is None. Return empty Low entity/relationship/source" + ) query_param.mode = "global" else: - ll_entities_context,ll_relations_context,ll_text_units_context = await _get_node_data( + ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) = await _get_node_data( ll_kewwords, knowledge_graph_inst, entities_vdb, text_chunks_db, - query_param - ) + query_param, + ) if query_param.mode in ["global", "hybrid"]: if hl_keywrds == "": - hl_entities_context,hl_relations_context,hl_text_units_context = "","","" - warnings.warn("High Level context is None. Return empty High entity/relationship/source") + hl_entities_context, hl_relations_context, hl_text_units_context = ( + "", + "", + "", + ) + warnings.warn( + "High Level context is None. 
Return empty High entity/relationship/source" + ) query_param.mode = "local" else: - hl_entities_context,hl_relations_context,hl_text_units_context = await _get_edge_data( + ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) = await _get_edge_data( hl_keywrds, knowledge_graph_inst, relationships_vdb, text_chunks_db, - query_param - ) - if query_param.mode == 'hybrid': - entities_context,relations_context,text_units_context = combine_contexts( - [hl_entities_context,ll_entities_context], - [hl_relations_context,ll_relations_context], - [hl_text_units_context,ll_text_units_context] - ) - elif query_param.mode == 'local': - entities_context,relations_context,text_units_context = ll_entities_context,ll_relations_context,ll_text_units_context - elif query_param.mode == 'global': - entities_context,relations_context,text_units_context = hl_entities_context,hl_relations_context,hl_text_units_context + query_param, + ) + if query_param.mode == "hybrid": + entities_context, relations_context, text_units_context = combine_contexts( + [hl_entities_context, ll_entities_context], + [hl_relations_context, ll_relations_context], + [hl_text_units_context, ll_text_units_context], + ) + elif query_param.mode == "local": + entities_context, relations_context, text_units_context = ( + ll_entities_context, + ll_relations_context, + ll_text_units_context, + ) + elif query_param.mode == "global": + entities_context, relations_context, text_units_context = ( + hl_entities_context, + hl_relations_context, + hl_text_units_context, + ) return f""" # -----Entities----- # ```csv @@ -550,7 +586,6 @@ async def _build_query_context( # """ - async def _get_node_data( query, knowledge_graph_inst: BaseGraphStorage, @@ -568,7 +603,7 @@ async def _get_node_data( ) if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") - + # 获取实体的度 node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] @@ -588,7 +623,7 @@ async def _get_node_data( ) logger.info( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" - ) + ) # 构建提示词 entites_section_list = [["id", "entity", "type", "description", "rank"]] @@ -625,7 +660,7 @@ async def _get_node_data( for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) - return entities_context,relations_context,text_units_context + return entities_context, relations_context, text_units_context async def _find_most_related_text_unit_from_entities( @@ -821,8 +856,7 @@ async def _get_edge_data( for i, t in enumerate(use_text_units): text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) - return entities_context,relations_context,text_units_context - + return entities_context, relations_context, text_units_context async def _find_most_related_entities_from_relationships( @@ -902,7 +936,7 @@ async def _find_related_text_unit_from_relationships( def combine_contexts(entities, relationships, sources): # Function to extract entities, relationships, and sources from context strings hl_entities, ll_entities = entities[0], entities[1] - hl_relationships, ll_relationships = relationships[0],relationships[1] + hl_relationships, ll_relationships = relationships[0], relationships[1] hl_sources, ll_sources = sources[0], sources[1] # Combine and deduplicate the 
entities combined_entities = process_combine_contexts(hl_entities, ll_entities) diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 389f45f2..0d4e599d 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -52,7 +52,7 @@ Output: """ PROMPTS["entity_extraction_examples"] = [ -"""Example 1: + """Example 1: Entity_types: [person, technology, mission, organization, location] Text: @@ -77,7 +77,7 @@ Output: ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter} #############################""", -"""Example 2: + """Example 2: Entity_types: [person, technology, mission, organization, location] Text: @@ -95,7 +95,7 @@ Output: ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter} ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter} #############################""", -"""Example 3: + """Example 3: Entity_types: [person, role, technology, organization, event, location, concept] Text: @@ -121,10 +121,12 @@ Output: ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter} -#############################""" +#############################""", ] -PROMPTS["summarize_entity_descriptions"] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +PROMPTS[ + "summarize_entity_descriptions" +] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. @@ -139,10 +141,14 @@ Description List: {description_list} Output: """ -PROMPTS["entiti_continue_extraction"] = """MANY entities were missed in the last extraction. Add them below using the same format: +PROMPTS[ + "entiti_continue_extraction" +] = """MANY entities were missed in the last extraction. Add them below using the same format: """ -PROMPTS["entiti_if_loop_extraction"] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added. 
+PROMPTS[ + "entiti_if_loop_extraction" +] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added. """ PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question." @@ -201,7 +207,7 @@ Output: """ PROMPTS["keywords_extraction_examples"] = [ - """Example 1: + """Example 1: Query: "How does international trade influence global economic stability?" ################ @@ -211,7 +217,7 @@ Output: "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"] }} #############################""", - """Example 2: + """Example 2: Query: "What are the environmental consequences of deforestation on biodiversity?" ################ @@ -220,8 +226,8 @@ Output: "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"], "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"] }} -#############################""", - """Example 3: +#############################""", + """Example 3: Query: "What is the role of education in reducing poverty?" ################ @@ -230,8 +236,8 @@ Output: "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"], "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"] }} -#############################""" -] +#############################""", +] PROMPTS["naive_rag_response"] = """---Role--- diff --git a/lightrag/utils.py b/lightrag/utils.py index fc739002..bdd1aa9e 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -56,7 +56,8 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: maybe_json_str = maybe_json_str.replace("'", '"') json.loads(maybe_json_str) return maybe_json_str - except: + except Exception: + pass # try: # content = ( # content.replace(kw_prompt[:-1], "") @@ -64,9 +65,9 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: # .replace("model", "") # .strip() # ) - # maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}" + # maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}" # json.loads(maybe_json_str) - + return None From b5050ba80c22df1cee5a9b10a4121c88e279f810 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:15:10 +0800 Subject: [PATCH 4/8] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 275 ++++++++++++++++++++++--------------- 1 file changed, 166 insertions(+), 109 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 08ce79d5..2e394b8a 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -114,9 +114,7 @@ class OracleDB: logger.info("Finished check all tables in Oracle database") - async def query( - self, sql: str, params: dict = None, multirows: bool = False - ) -> Union[dict, None]: + async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]: async with self.pool.acquire() as connection: connection.inputtypehandler = self.input_type_handler connection.outputtypehandler = self.output_type_handler @@ -175,11 +173,10 @@ class OracleKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> Union[dict, None]: """根据 id 获取 doc_full 数据.""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( - workspace=self.db.workspace, id=id - ) + SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] + params = {"workspace":self.db.workspace, 
"id":id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: data = res # {"data":res} # print (data) @@ -190,11 +187,11 @@ class OracleKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """根据 id 获取 doc_chunks 数据""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids]) - ) - # print("get_by_ids:"+SQL) - res = await self.db.query(SQL, multirows=True) + SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids])) + params = {"workspace":self.db.workspace} + #print("get_by_ids:"+SQL) + #print(params) + res = await self.db.query(SQL,params, multirows=True) if res: data = res # [{"data":i} for i in res] # print(data) @@ -204,12 +201,16 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" - SQL = SQL_TEMPLATES["filter_keys"].format( - table_name=N_T[self.namespace], - workspace=self.db.workspace, - ids=",".join([f"'{k}'" for k in keys]), - ) - res = await self.db.query(SQL, multirows=True) + SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], + ids=",".join([f"'{id}'" for id in keys])) + params = {"workspace":self.db.workspace} + try: + await self.db.query(SQL, params) + except Exception as e: + logger.error(f"Oracle database error: {e}") + print(SQL) + print(params) + res = await self.db.query(SQL, params,multirows=True) data = None if res: exist_keys = [key["id"] for key in res] @@ -246,29 +247,31 @@ class OracleKVStorage(BaseKVStorage): d["__vector__"] = embeddings[i] # print(list_data) for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) - - values = [ - item["__id__"], - item["content"], - self.db.workspace, - item["tokens"], - item["chunk_order_index"], - item["full_doc_id"], - item["__vector__"], - ] + merge_sql = SQL_TEMPLATES["merge_chunk"] + data = {"check_id":item["__id__"], + "id":item["__id__"], + "content":item["content"], + "workspace":self.db.workspace, + "tokens":item["tokens"], + "chunk_order_index":item["chunk_order_index"], + "full_doc_id":item["full_doc_id"], + "content_vector":item["__vector__"] + } # print(merge_sql) - await self.db.execute(merge_sql, values) + await self.db.execute(merge_sql, data) if self.namespace == "full_docs": for k, v in self._data.items(): # values.clear() - merge_sql = SQL_TEMPLATES["merge_doc_full"].format( - check_id=k, - ) - values = [k, self._data[k]["content"], self.db.workspace] + merge_sql = SQL_TEMPLATES["merge_doc_full"] + data = { + "check_id":k, + "id":k, + "content":v["content"], + "workspace":self.db.workspace + } # print(merge_sql) - await self.db.execute(merge_sql, values) + await self.db.execute(merge_sql, data) return left_data async def index_done_callback(self): @@ -298,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage): # 转换精度 dtype = str(embedding.dtype).upper() dimension = embedding.shape[0] - embedding_string = ", ".join(map(str, embedding.tolist())) + embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]" - SQL = SQL_TEMPLATES[self.namespace].format( - embedding_string=embedding_string, - dimension=dimension, - dtype=dtype, - workspace=self.db.workspace, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) + SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) + params = { + 
"embedding_string": embedding_string, + "workspace": self.db.workspace, + "top_k": top_k, + "better_than_threshold": self.cosine_better_than_threshold, + } # print(SQL) - results = await self.db.query(SQL, multirows=True) + results = await self.db.query(SQL,params=params, multirows=True) # print("vector search result:",results) return results @@ -344,22 +346,18 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"].format( - workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id - ) + merge_sql = SQL_TEMPLATES["merge_node"] + data = { + "workspace":self.db.workspace, + "name":entity_name, + "entity_type":entity_type, + "description":description, + "source_chunk_id":source_id, + "content":content, + "content_vector":content_vector + } # print(merge_sql) - await self.db.execute( - merge_sql, - [ - self.db.workspace, - entity_name, - entity_type, - description, - source_id, - content, - content_vector, - ], - ) + await self.db.execute(merge_sql,data) # self._graph.add_node(node_id, **node_data) async def upsert_edge( @@ -373,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage): keywords = edge_data["keywords"] description = edge_data["description"] source_chunk_id = edge_data["source_id"] + logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}") + content = keywords + source_name + target_name + description contents = [content] batches = [ @@ -384,27 +384,20 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"].format( - workspace=self.db.workspace, - source_name=source_name, - target_name=target_name, - source_chunk_id=source_chunk_id, - ) + merge_sql = SQL_TEMPLATES["merge_edge"] + data = { + "workspace":self.db.workspace, + "source_name":source_name, + "target_name":target_name, + "weight":weight, + "keywords":keywords, + "description":description, + "source_chunk_id":source_chunk_id, + "content":content, + "content_vector":content_vector + } # print(merge_sql) - await self.db.execute( - merge_sql, - [ - self.db.workspace, - source_name, - target_name, - weight, - keywords, - description, - source_chunk_id, - content, - content_vector, - ], - ) + await self.db.execute(merge_sql,data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: @@ -434,12 +427,14 @@ class OracleGraphStorage(BaseGraphStorage): #################### query method ################# async def has_node(self, node_id: str) -> bool: """根据节点id检查节点是否存在""" - SQL = SQL_TEMPLATES["has_node"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["has_node"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(SQL) # print(self.db.workspace, node_id) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Node exist!",res) return True @@ -449,13 +444,14 @@ class OracleGraphStorage(BaseGraphStorage): async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: """根据源和目标节点id检查边是否存在""" - SQL = SQL_TEMPLATES["has_edge"].format( - workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id, - ) + SQL = SQL_TEMPLATES["has_edge"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id, + 
"target_node_id":target_node_id + } # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Edge exist!",res) return True @@ -465,11 +461,13 @@ class OracleGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """根据节点id获取节点的度""" - SQL = SQL_TEMPLATES["node_degree"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["node_degree"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Node degree",res["degree"]) return res["degree"] @@ -485,12 +483,14 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: """根据节点id获取节点数据""" - SQL = SQL_TEMPLATES["get_node"].format( - workspace=self.db.workspace, node_id=node_id - ) + SQL = SQL_TEMPLATES["get_node"] + params = { + "workspace":self.db.workspace, + "node_id":node_id + } # print(self.db.workspace, node_id) # print(SQL) - res = await self.db.query(SQL) + res = await self.db.query(SQL,params) if res: # print("Get node!",self.db.workspace, node_id,res) return res @@ -502,12 +502,13 @@ class OracleGraphStorage(BaseGraphStorage): self, source_node_id: str, target_node_id: str ) -> Union[dict, None]: """根据源和目标节点id获取边""" - SQL = SQL_TEMPLATES["get_edge"].format( - workspace=self.db.workspace, - source_node_id=source_node_id, - target_node_id=target_node_id, - ) - res = await self.db.query(SQL) + SQL = SQL_TEMPLATES["get_edge"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id, + "target_node_id":target_node_id + } + res = await self.db.query(SQL,params) if res: # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) return res @@ -518,10 +519,12 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node_edges(self, source_node_id: str): """根据节点id获取节点的所有边""" if await self.has_node(source_node_id): - SQL = SQL_TEMPLATES["get_node_edges"].format( - workspace=self.db.workspace, source_node_id=source_node_id - ) - res = await self.db.query(sql=SQL, multirows=True) + SQL = SQL_TEMPLATES["get_node_edges"] + params = { + "workspace":self.db.workspace, + "source_node_id":source_node_id + } + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: data = [(i["source_name"], i["target_name"]) for i in res] # print("Get node edge!",self.db.workspace, source_node_id,data) @@ -529,7 +532,29 @@ class OracleGraphStorage(BaseGraphStorage): else: # print("Node Edge not exist!",self.db.workspace, source_node_id) return [] + + async def get_all_nodes(self, limit: int): + """查询所有节点""" + SQL = SQL_TEMPLATES["get_all_nodes"] + params = {"workspace":self.db.workspace, "limit":str(limit)} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res + async def get_all_edges(self, limit: int): + """查询所有边""" + SQL = SQL_TEMPLATES["get_all_edges"] + params = {"workspace":self.db.workspace, "limit":str(limit)} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res + + async def get_statistics(self): + SQL = SQL_TEMPLATES["get_statistics"] + params = {"workspace":self.db.workspace} + res = await self.db.query(sql=SQL,params=params, multirows=True) + if res: + return res N_T = { "full_docs": "LIGHTRAG_DOC_FULL", @@ -701,5 +726,37 @@ SQL_TEMPLATES = { ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and 
a.source_chunk_id=:source_chunk_id) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + "get_all_nodes":"""WITH t0 AS ( + SELECT name AS id, entity_type AS label, entity_type, description, + '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids + FROM lightrag_graph_nodes + WHERE workspace = :workspace + ORDER BY createtime DESC fetch first :limit rows only + ), t1 AS ( + SELECT t0.id, source_chunk_id + FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) ) + ), t2 AS ( + SELECT t1.id, LISTAGG(t2.content, '\n') content + FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id + GROUP BY t1.id + ) + SELECT t0.id, label, entity_type, description, t2.content + FROM t0 LEFT JOIN t2 ON t0.id = t2.id""", + "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, + t1.weight,t1.DESCRIPTION,t2.content + FROM LIGHTRAG_GRAPH_EDGES t1 + LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id + WHERE t1.workspace=:workspace + order by t1.CREATETIME DESC + fetch first :limit rows only""", + "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count, + count(distinct CASE WHEN type='edge' THEN id END) as edges_count + FROM ( + select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph + MATCH (a) WHERE a.workspace=:workspace columns(a.name as id)) + UNION + select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph + MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) + )""", } From 0ffd44b79c1b0b205df4a97b2a2aaa88535d5e51 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:21:01 +0800 Subject: [PATCH 5/8] Update oracle_impl.py --- lightrag/kg/oracle_impl.py | 185 ++++++++++++++++++------------------- 1 file changed, 91 insertions(+), 94 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 2e394b8a..8ed73772 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -114,7 +114,9 @@ class OracleDB: logger.info("Finished check all tables in Oracle database") - async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]: + async def query( + self, sql: str, params: dict = None, multirows: bool = False + ) -> Union[dict, None]: async with self.pool.acquire() as connection: connection.inputtypehandler = self.input_type_handler connection.outputtypehandler = self.output_type_handler @@ -174,9 +176,9 @@ class OracleKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> Union[dict, None]: """根据 id 获取 doc_full 数据.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"workspace":self.db.workspace, "id":id} + params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: data = res # {"data":res} # print (data) @@ -187,11 +189,13 @@ class OracleKVStorage(BaseKVStorage): # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """根据 id 获取 doc_chunks 数据""" - SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids])) - params = 
{"workspace":self.db.workspace} - #print("get_by_ids:"+SQL) - #print(params) - res = await self.db.query(SQL,params, multirows=True) + SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + # print("get_by_ids:"+SQL) + # print(params) + res = await self.db.query(SQL, params, multirows=True) if res: data = res # [{"data":i} for i in res] # print(data) @@ -201,16 +205,17 @@ class OracleKVStorage(BaseKVStorage): async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" - SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace], - ids=",".join([f"'{id}'" for id in keys])) - params = {"workspace":self.db.workspace} + SQL = SQL_TEMPLATES["filter_keys"].format( + table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) + ) + params = {"workspace": self.db.workspace} try: await self.db.query(SQL, params) except Exception as e: logger.error(f"Oracle database error: {e}") print(SQL) print(params) - res = await self.db.query(SQL, params,multirows=True) + res = await self.db.query(SQL, params, multirows=True) data = None if res: exist_keys = [key["id"] for key in res] @@ -248,15 +253,16 @@ class OracleKVStorage(BaseKVStorage): # print(list_data) for item in list_data: merge_sql = SQL_TEMPLATES["merge_chunk"] - data = {"check_id":item["__id__"], - "id":item["__id__"], - "content":item["content"], - "workspace":self.db.workspace, - "tokens":item["tokens"], - "chunk_order_index":item["chunk_order_index"], - "full_doc_id":item["full_doc_id"], - "content_vector":item["__vector__"] - } + data = { + "check_id": item["__id__"], + "id": item["__id__"], + "content": item["content"], + "workspace": self.db.workspace, + "tokens": item["tokens"], + "chunk_order_index": item["chunk_order_index"], + "full_doc_id": item["full_doc_id"], + "content_vector": item["__vector__"], + } # print(merge_sql) await self.db.execute(merge_sql, data) @@ -265,11 +271,11 @@ class OracleKVStorage(BaseKVStorage): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] data = { - "check_id":k, - "id":k, - "content":v["content"], - "workspace":self.db.workspace - } + "check_id": k, + "id": k, + "content": v["content"], + "workspace": self.db.workspace, + } # print(merge_sql) await self.db.execute(merge_sql, data) return left_data @@ -301,7 +307,7 @@ class OracleVectorDBStorage(BaseVectorStorage): # 转换精度 dtype = str(embedding.dtype).upper() dimension = embedding.shape[0] - embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]" + embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) params = { @@ -309,9 +315,9 @@ class OracleVectorDBStorage(BaseVectorStorage): "workspace": self.db.workspace, "top_k": top_k, "better_than_threshold": self.cosine_better_than_threshold, - } + } # print(SQL) - results = await self.db.query(SQL,params=params, multirows=True) + results = await self.db.query(SQL, params=params, multirows=True) # print("vector search result:",results) return results @@ -346,18 +352,18 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_node"] + merge_sql = SQL_TEMPLATES["merge_node"] data = { - "workspace":self.db.workspace, - "name":entity_name, - "entity_type":entity_type, - "description":description, - "source_chunk_id":source_id, - "content":content, - 
"content_vector":content_vector - } + "workspace": self.db.workspace, + "name": entity_name, + "entity_type": entity_type, + "description": description, + "source_chunk_id": source_id, + "content": content, + "content_vector": content_vector, + } # print(merge_sql) - await self.db.execute(merge_sql,data) + await self.db.execute(merge_sql, data) # self._graph.add_node(node_id, **node_data) async def upsert_edge( @@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage): keywords = edge_data["keywords"] description = edge_data["description"] source_chunk_id = edge_data["source_id"] - logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}") + logger.debug( + f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" + ) content = keywords + source_name + target_name + description contents = [content] @@ -384,20 +392,20 @@ class OracleGraphStorage(BaseGraphStorage): ) embeddings = np.concatenate(embeddings_list) content_vector = embeddings[0] - merge_sql = SQL_TEMPLATES["merge_edge"] + merge_sql = SQL_TEMPLATES["merge_edge"] data = { - "workspace":self.db.workspace, - "source_name":source_name, - "target_name":target_name, - "weight":weight, - "keywords":keywords, - "description":description, - "source_chunk_id":source_chunk_id, - "content":content, - "content_vector":content_vector - } + "workspace": self.db.workspace, + "source_name": source_name, + "target_name": target_name, + "weight": weight, + "keywords": keywords, + "description": description, + "source_chunk_id": source_chunk_id, + "content": content, + "content_vector": content_vector, + } # print(merge_sql) - await self.db.execute(merge_sql,data) + await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: @@ -428,13 +436,10 @@ class OracleGraphStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: """根据节点id检查节点是否存在""" SQL = SQL_TEMPLATES["has_node"] - params = { - "workspace":self.db.workspace, - "node_id":node_id - } + params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) # print(self.db.workspace, node_id) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Node exist!",res) return True @@ -446,12 +451,12 @@ class OracleGraphStorage(BaseGraphStorage): """根据源和目标节点id检查边是否存在""" SQL = SQL_TEMPLATES["has_edge"] params = { - "workspace":self.db.workspace, - "source_node_id":source_node_id, - "target_node_id":target_node_id - } + "workspace": self.db.workspace, + "source_node_id": source_node_id, + "target_node_id": target_node_id, + } # print(SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Edge exist!",res) return True @@ -462,12 +467,9 @@ class OracleGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: """根据节点id获取节点的度""" SQL = SQL_TEMPLATES["node_degree"] - params = { - "workspace":self.db.workspace, - "node_id":node_id - } + params = {"workspace": self.db.workspace, "node_id": node_id} # print(SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Node degree",res["degree"]) return res["degree"] @@ -484,13 +486,10 @@ class OracleGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: """根据节点id获取节点数据""" SQL = SQL_TEMPLATES["get_node"] - params = { - "workspace":self.db.workspace, - "node_id":node_id - } 
+ params = {"workspace": self.db.workspace, "node_id": node_id} # print(self.db.workspace, node_id) # print(SQL) - res = await self.db.query(SQL,params) + res = await self.db.query(SQL, params) if res: # print("Get node!",self.db.workspace, node_id,res) return res @@ -504,11 +503,11 @@ class OracleGraphStorage(BaseGraphStorage): """根据源和目标节点id获取边""" SQL = SQL_TEMPLATES["get_edge"] params = { - "workspace":self.db.workspace, - "source_node_id":source_node_id, - "target_node_id":target_node_id - } - res = await self.db.query(SQL,params) + "workspace": self.db.workspace, + "source_node_id": source_node_id, + "target_node_id": target_node_id, + } + res = await self.db.query(SQL, params) if res: # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) return res @@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage): """根据节点id获取节点的所有边""" if await self.has_node(source_node_id): SQL = SQL_TEMPLATES["get_node_edges"] - params = { - "workspace":self.db.workspace, - "source_node_id":source_node_id - } + params = {"workspace": self.db.workspace, "source_node_id": source_node_id} res = await self.db.query(sql=SQL, params=params, multirows=True) if res: data = [(i["source_name"], i["target_name"]) for i in res] @@ -532,30 +528,31 @@ class OracleGraphStorage(BaseGraphStorage): else: # print("Node Edge not exist!",self.db.workspace, source_node_id) return [] - + async def get_all_nodes(self, limit: int): """查询所有节点""" SQL = SQL_TEMPLATES["get_all_nodes"] - params = {"workspace":self.db.workspace, "limit":str(limit)} - res = await self.db.query(sql=SQL,params=params, multirows=True) + params = {"workspace": self.db.workspace, "limit": str(limit)} + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: return res async def get_all_edges(self, limit: int): """查询所有边""" SQL = SQL_TEMPLATES["get_all_edges"] - params = {"workspace":self.db.workspace, "limit":str(limit)} - res = await self.db.query(sql=SQL,params=params, multirows=True) + params = {"workspace": self.db.workspace, "limit": str(limit)} + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: return res - + async def get_statistics(self): SQL = SQL_TEMPLATES["get_statistics"] - params = {"workspace":self.db.workspace} - res = await self.db.query(sql=SQL,params=params, multirows=True) + params = {"workspace": self.db.workspace} + res = await self.db.query(sql=SQL, params=params, multirows=True) if res: return res + N_T = { "full_docs": "LIGHTRAG_DOC_FULL", "text_chunks": "LIGHTRAG_DOC_CHUNKS", @@ -727,7 +724,7 @@ SQL_TEMPLATES = { WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, - "get_all_nodes":"""WITH t0 AS ( + "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids FROM lightrag_graph_nodes @@ -743,20 +740,20 @@ SQL_TEMPLATES = { ) SELECT t0.id, label, entity_type, description, t2.content FROM t0 LEFT JOIN t2 ON t0.id = t2.id""", - "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, + "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, t1.weight,t1.DESCRIPTION,t2.content FROM LIGHTRAG_GRAPH_EDGES t1 LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on 
t1.source_chunk_id=t2.id WHERE t1.workspace=:workspace order by t1.CREATETIME DESC fetch first :limit rows only""", - "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count, + "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count, count(distinct CASE WHEN type='edge' THEN id END) as edges_count FROM ( - select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph + select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph MATCH (a) WHERE a.workspace=:workspace columns(a.name as id)) UNION - select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph + select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) )""", } From 5bde05ed533db1a9f4f948a76c29940b0f09efd2 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:19:28 +0800 Subject: [PATCH 6/8] add LightRAG init parameters in readme also fix some error --- README.md | 29 ++++++++++++++++++++ examples/lightrag_api_oracle_demo..py | 3 ++- examples/lightrag_oracle_demo.py | 4 +-- lightrag/llm.py | 2 +- lightrag/operate.py | 38 +++++++++++++-------------- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 6d5af135..ad0b66ef 100644 --- a/README.md +++ b/README.md @@ -511,6 +511,35 @@ if __name__ == "__main__": +### LightRAG init parameters + +| **Parameter** | **Type** | **Explanation** | **Default** | +| --- | --- | --- | --- | +| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | +| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` | +| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` | +| **graph\_storage** | `str` | Storage type for graph edges and nodes. 
Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` | +| **log\_level** | | Log level for application runtime | `logging.DEBUG` | +| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | +| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | +| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` | +| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` | +| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` | +| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` | +| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` | +| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` | +| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` | +| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` | +| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` | +| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` | +| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` | +| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` | +| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | | +| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | | +| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | +| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` | +| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | + ## API Server Implementation LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests. 
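To make the parameter table above concrete, here is a minimal initialization sketch. It only uses parameters documented in the table; the embedding and LLM functions are illustrative stubs (replace them with `openai_embedding` / `openai_complete_if_cache` or your own models), and the Oracle storage backends additionally require wiring a shared `OracleDB` connection through the storage classes' `db` property, as shown in the Oracle demo scripts in this patch.

```python
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc


# Illustrative stubs -- swap in real model calls in practice.
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return "stub response"


async def embedding_func(texts: list[str]) -> np.ndarray:
    return np.random.rand(len(texts), 1024)


rag = LightRAG(
    working_dir="./my_working_dir",          # where the cache is stored
    chunk_token_size=512,                    # default is 1200, see the table above
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=512,
        func=embedding_func,
    ),
    kv_storage="JsonKVStorage",              # or "OracleKVStorage"
    vector_storage="NanoVectorDBStorage",    # or "OracleVectorDBStorage"
    graph_storage="NetworkXStorage",         # or "Neo4JStorage" / "OracleGraphStorage"
    enable_llm_cache=False,
    addon_params={"example_number": 1, "language": "Simplified Chinese"},
)
```

Any of the values above can be omitted to fall back to the defaults listed in the table.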
diff --git a/examples/lightrag_api_oracle_demo..py b/examples/lightrag_api_oracle_demo..py index 8aaa2cf5..774ef61f 100644 --- a/examples/lightrag_api_oracle_demo..py +++ b/examples/lightrag_api_oracle_demo..py @@ -81,7 +81,7 @@ async def get_embedding_dim(): async def init(): # Detect embedding dimension - embedding_dimension = 1024 # await get_embedding_dim() + embedding_dimension = await get_embedding_dim() print(f"Detected embedding dimension: {embedding_dimension}") # Create Oracle DB connection # The `config` parameter is the connection configuration of Oracle DB @@ -105,6 +105,7 @@ async def init(): await oracle_db.check_tables() # Initialize LightRAG # We use Oracle DB as the KV/vector/graph storage + # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( enable_llm_cache=False, working_dir=WORKING_DIR, diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 630c1fd8..02fb569d 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -84,6 +84,7 @@ async def main(): # Initialize LightRAG # We use Oracle DB as the KV/vector/graph storage + # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( enable_llm_cache=False, working_dir=WORKING_DIR, @@ -96,8 +97,7 @@ async def main(): ), graph_storage="OracleGraphStorage", kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage", - addon_params={"example_number": 1, "language": "Simplfied Chinese"}, + vector_storage="OracleVectorDBStorage" ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool diff --git a/lightrag/llm.py b/lightrag/llm.py index d3729941..6a191a0f 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -72,7 +72,7 @@ async def openai_complete_if_cache( content = response.choices[0].message.content if r"\u" in content: content = content.encode("utf-8").decode("unicode_escape") - print(content) + # print(content) if hashing_kv is not None: await hashing_kv.upsert( {args_hash: {"return": response.choices[0].message.content, "model": model}} diff --git a/lightrag/operate.py b/lightrag/operate.py index c4740e70..4265ebcb 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -571,19 +571,19 @@ async def _build_query_context( hl_text_units_context, ) return f""" -# -----Entities----- -# ```csv -# {entities_context} -# ``` -# -----Relationships----- -# ```csv -# {relations_context} -# ``` -# -----Sources----- -# ```csv -# {text_units_context} -# ``` -# """ +-----Entities----- +```csv +{entities_context} +``` +-----Relationships----- +```csv +{relations_context} +``` +-----Sources----- +```csv +{text_units_context} +``` +""" async def _get_node_data( @@ -593,18 +593,18 @@ async def _get_node_data( text_chunks_db: BaseKVStorage[TextChunkSchema], query_param: QueryParam, ): - # 获取相似的实体 + # get similar entities results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return None - # 获取实体信息 + # get entity information node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results] ) if not all([n is not None for n in node_datas]): logger.warning("Some nodes are missing, maybe the storage is damaged") - # 获取实体的度 + # get entity degree node_degrees = await asyncio.gather( *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results] ) @@ -613,11 +613,11 @@ async def _get_node_data( for k, n, d in zip(results, 
node_datas, node_degrees) if n is not None ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. - # 根据实体获取文本片段 + # get entitytext chunk use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst ) - # 获取关联的边 + # get relate edges use_relations = await _find_most_related_edges_from_entities( node_datas, query_param, knowledge_graph_inst ) @@ -625,7 +625,7 @@ async def _get_node_data( f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units" ) - # 构建提示词 + # build prompt entites_section_list = [["id", "entity", "type", "description", "rank"]] for i, n in enumerate(node_datas): entites_section_list.append( From 28bc45c8f59dbc2f4aca96ccd71f3fda5b02756e Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:21:39 +0800 Subject: [PATCH 7/8] fix formate --- examples/lightrag_oracle_demo.py | 2 +- lightrag/operate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 02fb569d..2aa47c78 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -97,7 +97,7 @@ async def main(): ), graph_storage="OracleGraphStorage", kv_storage="OracleKVStorage", - vector_storage="OracleVectorDBStorage" + vector_storage="OracleVectorDBStorage", ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool diff --git a/lightrag/operate.py b/lightrag/operate.py index 4265ebcb..c36af2f3 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -613,7 +613,7 @@ async def _get_node_data( for k, n, d in zip(results, node_datas, node_degrees) if n is not None ] # what is this text_chunks_db doing. dont remember it in airvx. check the diagram. 
- # get entitytext chunk + # get entitytext chunk use_text_units = await _find_most_related_text_unit_from_entities( node_datas, query_param, text_chunks_db, knowledge_graph_inst ) From 69867da89f82381bc0ed7100ada0a20f11ce0c26 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:20:10 +0800 Subject: [PATCH 8/8] Update insert_custom_kg.py --- examples/insert_custom_kg.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/insert_custom_kg.py b/examples/insert_custom_kg.py index bbabe6a9..19da0f29 100644 --- a/examples/insert_custom_kg.py +++ b/examples/insert_custom_kg.py @@ -1,5 +1,5 @@ import os -from lightrag import LightRAG, QueryParam +from lightrag import LightRAG from lightrag.llm import gpt_4o_mini_complete ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() @@ -24,50 +24,50 @@ custom_kg = { "entity_name": "CompanyA", "entity_type": "Organization", "description": "A major technology company", - "source_id": "Source1" + "source_id": "Source1", }, { "entity_name": "ProductX", "entity_type": "Product", "description": "A popular product developed by CompanyA", - "source_id": "Source1" + "source_id": "Source1", }, { "entity_name": "PersonA", "entity_type": "Person", "description": "A renowned researcher in AI", - "source_id": "Source2" + "source_id": "Source2", }, { "entity_name": "UniversityB", "entity_type": "Organization", "description": "A leading university specializing in technology and sciences", - "source_id": "Source2" + "source_id": "Source2", }, { "entity_name": "CityC", "entity_type": "Location", "description": "A large metropolitan city known for its culture and economy", - "source_id": "Source3" + "source_id": "Source3", }, { "entity_name": "EventY", "entity_type": "Event", "description": "An annual technology conference held in CityC", - "source_id": "Source3" + "source_id": "Source3", }, { "entity_name": "CompanyD", "entity_type": "Organization", "description": "A financial services company specializing in insurance", - "source_id": "Source4" + "source_id": "Source4", }, { "entity_name": "ServiceZ", "entity_type": "Service", "description": "An insurance product offered by CompanyD", - "source_id": "Source4" - } + "source_id": "Source4", + }, ], "relationships": [ { @@ -76,7 +76,7 @@ custom_kg = { "description": "CompanyA develops ProductX", "keywords": "develop, produce", "weight": 1.0, - "source_id": "Source1" + "source_id": "Source1", }, { "src_id": "PersonA", @@ -84,7 +84,7 @@ custom_kg = { "description": "PersonA works at UniversityB", "keywords": "employment, affiliation", "weight": 0.9, - "source_id": "Source2" + "source_id": "Source2", }, { "src_id": "CityC", @@ -92,7 +92,7 @@ custom_kg = { "description": "EventY is hosted in CityC", "keywords": "host, location", "weight": 0.8, - "source_id": "Source3" + "source_id": "Source3", }, { "src_id": "CompanyD", @@ -100,9 +100,9 @@ custom_kg = { "description": "CompanyD provides ServiceZ", "keywords": "provide, offer", "weight": 1.0, - "source_id": "Source4" - } - ] + "source_id": "Source4", + }, + ], } rag.insert_custom_kg(custom_kg)
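Once the custom knowledge graph above has been inserted, it can be queried like any other indexed content. The sketch below assumes the synchronous `query` wrapper and the `QueryParam` options used elsewhere in this patch (`mode`, `top_k`, `only_need_context`); the question text is purely illustrative.

```python
from lightrag import QueryParam

# Retrieve an answer grounded in the hand-built entities and relationships.
# "hybrid" mode combines entity-level (low-level) and relationship-level
# (high-level) retrieval; set only_need_context=True to inspect the assembled
# CSV context instead of generating an LLM answer.
print(
    rag.query(
        "What does CompanyA develop, and where is EventY hosted?",
        param=QueryParam(mode="hybrid", top_k=60, only_need_context=False),
    )
)
```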