From 957bcf865998e6f806eb02c45cfc5ec243ebeb75 Mon Sep 17 00:00:00 2001
From: jin <52519003+jin38324@users.noreply.github.com>
Date: Tue, 7 Jan 2025 13:51:20 +0800
Subject: [PATCH 01/12] Organize files

move some test files from root to examples
---
 .DS_Store                                     | Bin 8196 -> 0 bytes
 .gitignore                                    |   1 +
 .../get_all_edges_nx.py                       |   0
 test.py => examples/test.py                   |   0
 test_chromadb.py => examples/test_chromadb.py |   0
 test_neo4j.py => examples/test_neo4j.py       |   0
 6 files changed, 1 insertion(+)
 delete mode 100644 .DS_Store
 rename get_all_edges_nx.py => examples/get_all_edges_nx.py (100%)
 rename test.py => examples/test.py (100%)
 rename test_chromadb.py => examples/test_chromadb.py (100%)
 rename test_neo4j.py => examples/test_neo4j.py (100%)

diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index 7489d923a9e7375a0aadb3d45e73e969d67f761d..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

Date: Fri, 10 Jan 2025 11:36:28 +0800
Subject: [PATCH 02/12] update Oracle support

add cache support, fix bug
---
 examples/lightrag_oracle_demo.py |  45 ++++--
 lightrag/kg/oracle_impl.py       | 265 +++++++++++++++++++++++++++----
 lightrag/lightrag.py             |   2 +
 lightrag/operate.py              |  16 +-
 lightrag/utils.py                |   4 +-
 5 files changed, 284 insertions(+), 48 deletions(-)

diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index bbb69319..2cc8e018 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
 APIKEY = "ocigenerativeai"
 CHATMODEL = "cohere.command-r-plus"
 EMBEDMODEL = "cohere.embed-multilingual-v3.0"
-
+CHUNK_TOKEN_SIZE = 1024
+MAX_TOKENS = 4000
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -86,27 +87,49 @@ async def main():
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
     rag = LightRAG(
+        working_dir=WORKING_DIR,
+        entity_extract_max_gleaning = 1,
+
         enable_llm_cache=False,
-        working_dir=WORKING_DIR,
-        chunk_token_size=512,
+        embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
+        enable_llm_cache_for_entity_extract = True,
+
+        chunk_token_size=CHUNK_TOKEN_SIZE,
+        llm_model_max_token_size = MAX_TOKENS,
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
-            max_token_size=512,
+            max_token_size=500,
             func=embedding_func,
-        ),
-        graph_storage="OracleGraphStorage",
-        kv_storage="OracleKVStorage",
+        ),
+
+        graph_storage = "OracleGraphStorage",
+        kv_storage = "OracleKVStorage",
         vector_storage="OracleVectorDBStorage",
+        doc_status_storage="OracleDocStatusStorage",
+
+        addon_params = {"example_number":1, "language":"Simplified Chinese"},
     )

-    # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
+    # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
     rag.key_string_value_json_storage_cls.db = oracle_db
     rag.vector_db_storage_cls.db = oracle_db
-    # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
-    rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
+    rag.graph_storage_cls.db = oracle_db
+    rag.doc_status_storage_cls.db = oracle_db
+    rag.doc_status.db = oracle_db
+    rag.full_docs.db = oracle_db
+    rag.text_chunks.db = oracle_db
+    rag.llm_response_cache.db = oracle_db
+    rag.key_string_value_json_storage_cls.db = oracle_db
+    rag.chunks_vdb.db = oracle_db
+    rag.relationships_vdb.db = oracle_db
+    rag.entities_vdb.db = oracle_db
+    rag.graph_storage_cls.db = oracle_db
+    rag.chunk_entity_relation_graph.db = oracle_db
+    rag.llm_response_cache.db = oracle_db
+
+    rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
+
     # Extract and Insert into LightRAG storage
     with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
         await rag.ainsert(f.read())
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 34745312..d464bcc4 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -3,7 +3,7 @@ import asyncio
 # import html
 # import os
 from dataclasses import dataclass
-from typing import Union
+from typing import Union, List, Dict, Set, Any, Tuple
 import numpy as np
 import array
@@ -12,6 +12,9 @@ from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
 )
 import oracledb
@@ -167,6 +170,9 @@ class OracleDB:
 @dataclass
 class OracleKVStorage(BaseKVStorage):
     # should pass db object to self.db
+    db: OracleDB = None
+    meta_fields = None
+
     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -174,28 +180,56 @@ class OracleKVStorage(BaseKVStorage):
     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> Union[dict, None]:
-        """根据 id 获取 doc_full 数据."""
+        """get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL, params)
+        if "llm_response_cache" == self.namespace:
+            array_res = await self.db.query(SQL, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+        else:
+            res = await self.db.query(SQL, params)
         if res:
-            data = res  # {"data":res}
-            # print (data)
-            return data
+            return res
+        else:
+            return None
+
+    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+        """Specifically for llm_response_cache."""
+        SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
+        if "llm_response_cache" == self.namespace:
+            array_res = await self.db.query(SQL, params, multirows=True)
+            res = {}
+            for row in
array_res: + res[row["id"]] = row + return res else: return None # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: - """根据 id 获取 doc_chunks 数据""" + """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} # print("get_by_ids:"+SQL) - # print(params) res = await self.db.query(SQL, params, multirows=True) + if "llm_response_cache" == self.namespace: + modes = set() + dict_res: dict[str, dict] = {} + for row in res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in res: + dict_res[row["mode"]][row["id"]] = row + res = [{k: v} for k, v in dict_res.items()] + if res: data = res # [{"data":i} for i in res] # print(data) @@ -204,7 +238,7 @@ class OracleKVStorage(BaseKVStorage): return None async def filter_keys(self, keys: list[str]) -> set[str]: - """过滤掉重复内容""" + """remove duplicated""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) ) @@ -271,13 +305,26 @@ class OracleKVStorage(BaseKVStorage): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] data = { - "check_id": k, "id": k, "content": v["content"], "workspace": self.db.workspace, } # print(merge_sql) await self.db.execute(merge_sql, data) + + if self.namespace == "llm_response_cache": + for mode, items in data.items(): + for k, v in items.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "cache_mode": mode, + } + + await self.db.execute(upsert_sql, _data) return left_data async def index_done_callback(self): @@ -285,8 +332,99 @@ class OracleKVStorage(BaseKVStorage): logger.info("full doc and chunk data had been saved into oracle db!") +@dataclass +class OracleDocStatusStorage(DocStatusStorage): + """Oracle implementation of document status storage""" + # should pass db object to self.db + db: OracleDB = None + meta_fields = None + + def __post_init__(self): + pass + + async def filter_keys(self, ids: list[str]) -> set[str]: + """Return keys that don't exist in storage""" + SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( + ids = ",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + res = await self.db.query(SQL, params, True) + # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. + if res: + existed = set([element["id"] for element in res]) + return set(ids) - existed + else: + return set(ids) + + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + SQL = SQL_TEMPLATES["get_status_counts"] + params = {"workspace": self.db.workspace} + res = await self.db.query(SQL, params, True) + # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] 
+ counts = {} + for doc in res: + counts[doc["status"]] = doc["count"] + return counts + + async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]: + """Get all documents by status""" + SQL = SQL_TEMPLATES["get_docs_by_status"] + params = {"workspace": self.db.workspace, "status": status} + res = await self.db.query(SQL, params, True) + # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] + # Converting to be a dict + return { + element["id"]: DocProcessingStatus( + #content_summary=element["content_summary"], + content_summary = "", + content_length=element["CONTENT_LENGTH"], + status=element["STATUS"], + created_at=element["CREATETIME"], + updated_at=element["UPDATETIME"], + chunks_count=-1, + #chunks_count=element["chunks_count"], + ) + for element in res + } + + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + return await self.get_docs_by_status(DocStatus.FAILED) + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + return await self.get_docs_by_status(DocStatus.PENDING) + + async def index_done_callback(self): + """Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here""" + logger.info("Doc status had been saved into ORACLE db!") + + async def upsert(self, data: dict[str, dict]): + """Update or insert document status + + Args: + data: Dictionary of document IDs and their status data + """ + SQL = SQL_TEMPLATES["merge_doc_status"] + for k, v in data.items(): + # chunks_count is optional + params = { + "workspace": self.db.workspace, + "id": k, + "content_summary": v["content_summary"], + "content_length": v["content_length"], + "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, + "status": v["status"], + } + await self.db.execute(SQL, params) + return data + + @dataclass class OracleVectorDBStorage(BaseVectorStorage): + # should pass db object to self.db + db: OracleDB = None cosine_better_than_threshold: float = 0.2 def __post_init__(self): @@ -564,13 +702,18 @@ N_T = { TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id varchar(256)PRIMARY KEY, + id varchar(256), workspace varchar(1024), doc_name varchar(1024), content CLOB, meta JSON, + content_summary varchar(1024), + content_length NUMBER, + status varchar(256), + chunks_count NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL + updatetime TIMESTAMP DEFAULT NULL, + error varchar(4096) )""" }, "LIGHTRAG_DOC_CHUNKS": { @@ -618,10 +761,16 @@ TABLES = { }, "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - id varchar(256) PRIMARY KEY, - send clob, - return clob, - model varchar(1024), + id varchar(256) PRIMARY KEY, + workspace varchar(1024), + cache_mode varchar(256), + model_name varchar(256), + original_prompt clob, + return_value clob, + embedding CLOB, + embedding_shape NUMBER, + embedding_min NUMBER, + embedding_max NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL )""" @@ -647,22 +796,70 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from 
LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + + "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", + + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", + + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:id,:content,:workspace) - """, + + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace) values(:id,:content,:workspace)""", + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + USING DUAL + ON (a.id = :check_id) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + + "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a + USING DUAL + ON (a.id = :id) + WHEN NOT MATCHED THEN + INSERT (workspace,id,original_prompt,return_value,cache_mode) + VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) + WHEN MATCHED THEN UPDATE + SET original_prompt = :original_prompt, + return_value = :return_value, + cache_mode = :cache_mode, + updatetime = SYSDATE""", + + "get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})", + + "get_status_counts": """SELECT status as "status", COUNT(1) as "count" + FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""", + + "get_docs_by_status": """select content_length,status, + TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime + from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""", + + "merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status) + WHEN MATCHED THEN UPDATE + SET content_summary = :content_summary, + content_length = :content_length, + chunks_count = :chunks_count, + status = :status, + updatetime = SYSDATE""", + # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) 
as distance @@ -714,16 +911,22 @@ SQL_TEMPLATES = { COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.name=:name) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cbe49da2..b6d5238e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -79,6 +79,7 @@ Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") +OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage") MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") @@ -290,6 +291,7 @@ class LightRAG: # kv storage "JsonKVStorage": JsonKVStorage, "OracleKVStorage": OracleKVStorage, + "OracleDocStatusStorage":OracleDocStatusStorage, "MongoKVStorage": MongoKVStorage, "TiDBKVStorage": TiDBKVStorage, # vector storage diff --git a/lightrag/operate.py b/lightrag/operate.py index b2c4d215..45ba9656 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -59,13 +59,15 @@ async def _handle_entity_relation_summary( description: str, global_config: dict, ) -> str: + """Handle entity relation summary + For each entity or relation, input is the combined description of already existing description and new description. + If too long, use LLM to summarize. 
+ """ use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] - language = global_config["addon_params"].get( - "language", PROMPTS["DEFAULT_LANGUAGE"] - ) + language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"]) tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) if len(tokens) < summary_max_tokens: # No need for summary @@ -139,6 +141,7 @@ async def _merge_nodes_then_upsert( knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): + """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] @@ -319,7 +322,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages) + history = json.dumps(history_messages,ensure_ascii=False) _prompt = history + "\n" + input_text else: _prompt = input_text @@ -351,6 +354,11 @@ async def extract_entities( return await use_llm_func(input_text) async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """"Prpocess a single chunk + Args: + chunk_key_dp (tuple[str, TextChunkSchema]): + ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + """ nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] diff --git a/lightrag/utils.py b/lightrag/utils.py index 1f6bf405..56e4191c 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -36,7 +36,7 @@ logger = logging.getLogger("lightrag") def set_logger(log_file: str): logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file) + file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -473,7 +473,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"] llm_model_func = hashing_kv.global_config.get("llm_model_func") current_embedding = await embedding_model_func([prompt]) From ef61ffe444deff085665468e61046822bc84d7f0 Mon Sep 17 00:00:00 2001 From: Gurjot Singh Date: Tue, 14 Jan 2025 22:10:47 +0530 Subject: [PATCH 03/12] Add custom function with separate keyword extraction for user's query and a separate prompt --- lightrag/base.py | 2 + lightrag/lightrag.py | 110 ++++++++++++++++++++++++ lightrag/operate.py | 200 +++++++++++++++++++++++++++++++++++++++++++ test.py | 2 +- 4 files changed, 313 insertions(+), 1 deletion(-) diff --git a/lightrag/base.py b/lightrag/base.py index 94a39cf3..7b3504d0 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -31,6 +31,8 @@ class QueryParam: max_token_for_global_context: int = 4000 # Number of tokens for the entity descriptions max_token_for_local_context: int = 4000 + hl_keywords: list[str] = field(default_factory=list) + ll_keywords: list[str] = field(default_factory=list) @dataclass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 596fbdbf..e8859071 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ 
-17,6 +17,8 @@ from .operate import ( kg_query, naive_query, mix_kg_vector_query, + extract_keywords_only, + kg_query_with_keywords, ) from .utils import ( @@ -753,6 +755,114 @@ class LightRAG: await self._query_done() return response + def query_with_separate_keyword_extraction( + self, + query: str, + prompt: str, + param: QueryParam = QueryParam() + ): + """ + 1. Extract keywords from the 'query' using new function in operate.py. + 2. Then run the standard aquery() flow with the final prompt (formatted_question). + """ + + loop = always_get_an_event_loop() + return loop.run_until_complete(self.aquery_with_separate_keyword_extraction(query, prompt, param)) + + async def aquery_with_separate_keyword_extraction( + self, + query: str, + prompt: str, + param: QueryParam = QueryParam() + ): + """ + 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. + 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. + """ + + # --------------------- + # STEP 1: Keyword Extraction + # --------------------- + # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords). + hl_keywords, ll_keywords = await extract_keywords_only( + text=query, + param=param, + global_config=asdict(self), + hashing_kv=self.llm_response_cache or self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ) + ) + + param.hl_keywords=hl_keywords, + param.ll_keywords=ll_keywords, + + # --------------------- + # STEP 2: Final Query Logic + # --------------------- + + # Create a new string with the prompt and the keywords + ll_keywords_str = ", ".join(ll_keywords) + hl_keywords_str = ", ".join(hl_keywords) + formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}" + + if param.mode in ["local", "global", "hybrid"]: + response = await kg_query_with_keywords( + formatted_question, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + elif param.mode == "naive": + response = await naive_query( + formatted_question, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + elif param.mode == "mix": + response = await mix_kg_vector_query( + formatted_question, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + else: + raise ValueError(f"Unknown mode {param.mode}") + + await self._query_done() + return response + async def _query_done(self): tasks = [] for storage_inst in [self.llm_response_cache]: diff --git a/lightrag/operate.py 
b/lightrag/operate.py index 7216c07f..f4993873 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -680,6 +680,206 @@ async def kg_query( ) return response +async def kg_query_with_keywords( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, +) -> str: + """ + Refactored kg_query that does NOT extract keywords by itself. + It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. + Then it uses those to build context and produce a final LLM response. + """ + + # --------------------------- + # 0) Handle potential cache + # --------------------------- + use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + + # --------------------------- + # 1) RETRIEVE KEYWORDS FROM query_param + # --------------------------- + + # If these fields don't exist, default to empty lists/strings. + hl_keywords = getattr(query_param, "hl_keywords", []) or [] + ll_keywords = getattr(query_param, "ll_keywords", []) or [] + + # If neither has any keywords, you could handle that logic here. + if not hl_keywords and not ll_keywords: + logger.warning("No keywords found in query_param. Could default to global mode or fail.") + return PROMPTS["fail_response"] + if not ll_keywords and query_param.mode in ["local", "hybrid"]: + logger.warning("low_level_keywords is empty, switching to global mode.") + query_param.mode = "global" + if not hl_keywords and query_param.mode in ["global", "hybrid"]: + logger.warning("high_level_keywords is empty, switching to local mode.") + query_param.mode = "local" + + # Flatten low-level and high-level keywords if needed + ll_keywords_flat = [item for sublist in ll_keywords for item in sublist] if any(isinstance(i, list) for i in ll_keywords) else ll_keywords + hl_keywords_flat = [item for sublist in hl_keywords for item in sublist] if any(isinstance(i, list) for i in hl_keywords) else hl_keywords + + # Join the flattened lists + ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else "" + hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else "" + + keywords = [ll_keywords_str, hl_keywords_str] + + logger.info("Using %s mode for query processing", query_param.mode) + + # --------------------------- + # 2) BUILD CONTEXT + # --------------------------- + context = await _build_query_context( + keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + ) + if not context: + return PROMPTS["fail_response"] + + # If only context is needed, return it + if query_param.only_need_context: + return context + + # --------------------------- + # 3) BUILD THE SYSTEM PROMPT + CALL LLM + # --------------------------- + sys_prompt_temp = PROMPTS["rag_response"] + sys_prompt = sys_prompt_temp.format( + context_data=context, response_type=query_param.response_type + ) + + if query_param.only_need_prompt: + return sys_prompt + + # Now call the LLM with the final system prompt + response = await use_model_func( + query, + system_prompt=sys_prompt, + stream=query_param.stream, + ) + + # Clean up the response + if isinstance(response, str) and 
len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + + # --------------------------- + # 4) SAVE TO CACHE + # --------------------------- + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response + +async def extract_keywords_only( + text: str, + param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, +) -> tuple[list[str], list[str]]: + """ + Extract high-level and low-level keywords from the given 'text' using the LLM. + This method does NOT build the final RAG context or provide a final answer. + It ONLY extracts keywords (hl_keywords, ll_keywords). + """ + + # 1. Handle cache if needed + args_hash = compute_args_hash(param.mode, text) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, text, param.mode + ) + if cached_response is not None: + # parse the cached_response if it’s JSON containing keywords + # or simply return (hl_keywords, ll_keywords) from cached + # Assuming cached_response is in the same JSON structure: + match = re.search(r"\{.*\}", cached_response, re.DOTALL) + if match: + keywords_data = json.loads(match.group(0)) + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + return hl_keywords, ll_keywords + return [], [] + + # 2. Build the examples + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): + examples = "\n".join( + PROMPTS["keywords_extraction_examples"][: int(example_number)] + ) + else: + examples = "\n".join(PROMPTS["keywords_extraction_examples"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) + + # 3. Build the keyword-extraction prompt + kw_prompt_temp = PROMPTS["keywords_extraction"] + kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language) + + # 4. Call the LLM for keyword extraction + use_model_func = global_config["llm_model_func"] + result = await use_model_func(kw_prompt, keyword_extraction=True) + + # 5. Parse out JSON from the LLM response + match = re.search(r"\{.*\}", result, re.DOTALL) + if not match: + logger.error("No JSON-like structure found in the result.") + return [], [] + try: + keywords_data = json.loads(match.group(0)) + except json.JSONDecodeError as e: + logger.error(f"JSON parsing error: {e}") + return [], [] + + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + + # 6. 
Cache the result if needed + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=result, + prompt=text, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=param.mode, + ), + ) + return hl_keywords, ll_keywords async def _build_query_context( query: list, diff --git a/test.py b/test.py index 80bcaa6d..895f0b30 100644 --- a/test.py +++ b/test.py @@ -39,4 +39,4 @@ print( # Perform hybrid search print( rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) -) +) \ No newline at end of file From bc79f6650e57236b572e81e022d3fdbfadb51552 Mon Sep 17 00:00:00 2001 From: Gurjot Singh Date: Tue, 14 Jan 2025 22:23:14 +0530 Subject: [PATCH 04/12] Fix linting errors --- lightrag/lightrag.py | 40 ++++++++++++++++++++-------------------- lightrag/operate.py | 23 ++++++++++++++++++----- test.py | 2 +- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index e8859071..cacdfc50 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -756,10 +756,7 @@ class LightRAG: return response def query_with_separate_keyword_extraction( - self, - query: str, - prompt: str, - param: QueryParam = QueryParam() + self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ 1. Extract keywords from the 'query' using new function in operate.py. @@ -767,13 +764,12 @@ class LightRAG: """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery_with_separate_keyword_extraction(query, prompt, param)) - + return loop.run_until_complete( + self.aquery_with_separate_keyword_extraction(query, prompt, param) + ) + async def aquery_with_separate_keyword_extraction( - self, - query: str, - prompt: str, - param: QueryParam = QueryParam() + self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 
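Taken together, patches 03 and 04 expose a two-step flow: keywords are extracted from the user's query first, then injected into `QueryParam` and a formatted prompt before the regular retrieval path runs. A minimal usage sketch (the `rag` instance, query text, and prompt text below are illustrative, not part of the patches):

```python
# Illustrative only: assumes `rag` is an already-initialized LightRAG instance.
from lightrag import QueryParam

response = rag.query_with_separate_keyword_extraction(
    query="What are the top themes in this story?",               # used for keyword extraction
    prompt="Answer concisely and cite the relevant characters.",  # merged into the final prompt
    param=QueryParam(mode="hybrid"),
)
print(response)
```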
@@ -788,20 +784,21 @@ class LightRAG: text=query, param=param, global_config=asdict(self), - hashing_kv=self.llm_response_cache or self.key_string_value_json_storage_cls( + hashing_kv=self.llm_response_cache + or self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), embedding_func=None, - ) + ), ) - - param.hl_keywords=hl_keywords, - param.ll_keywords=ll_keywords, - + + param.hl_keywords = (hl_keywords,) + param.ll_keywords = (ll_keywords,) + # --------------------- # STEP 2: Final Query Logic # --------------------- - + # Create a new string with the prompt and the keywords ll_keywords_str = ", ".join(ll_keywords) hl_keywords_str = ", ".join(hl_keywords) @@ -817,7 +814,8 @@ class LightRAG: param, asdict(self), hashing_kv=self.llm_response_cache - if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -832,7 +830,8 @@ class LightRAG: param, asdict(self), hashing_kv=self.llm_response_cache - if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), @@ -850,7 +849,8 @@ class LightRAG: param, asdict(self), hashing_kv=self.llm_response_cache - if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config") + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") else self.key_string_value_json_storage_cls( namespace="llm_response_cache", global_config=asdict(self), diff --git a/lightrag/operate.py b/lightrag/operate.py index f4993873..7df489b3 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -680,6 +680,7 @@ async def kg_query( ) return response + async def kg_query_with_keywords( query: str, knowledge_graph_inst: BaseGraphStorage, @@ -717,8 +718,10 @@ async def kg_query_with_keywords( # If neither has any keywords, you could handle that logic here. if not hl_keywords and not ll_keywords: - logger.warning("No keywords found in query_param. Could default to global mode or fail.") - return PROMPTS["fail_response"] + logger.warning( + "No keywords found in query_param. Could default to global mode or fail." 
+        )
+        return PROMPTS["fail_response"]
     if not ll_keywords and query_param.mode in ["local", "hybrid"]:
         logger.warning("low_level_keywords is empty, switching to global mode.")
         query_param.mode = "global"
@@ -727,8 +730,16 @@ async def kg_query_with_keywords(
         query_param.mode = "local"

     # Flatten low-level and high-level keywords if needed
-    ll_keywords_flat = [item for sublist in ll_keywords for item in sublist] if any(isinstance(i, list) for i in ll_keywords) else ll_keywords
-    hl_keywords_flat = [item for sublist in hl_keywords for item in sublist] if any(isinstance(i, list) for i in hl_keywords) else hl_keywords
+    ll_keywords_flat = (
+        [item for sublist in ll_keywords for item in sublist]
+        if any(isinstance(i, list) for i in ll_keywords)
+        else ll_keywords
+    )
+    hl_keywords_flat = (
+        [item for sublist in hl_keywords for item in sublist]
+        if any(isinstance(i, list) for i in hl_keywords)
+        else hl_keywords
+    )

     # Join the flattened lists
     ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
     hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
@@ -766,7 +777,7 @@ async def kg_query_with_keywords(

     if query_param.only_need_prompt:
         return sys_prompt
-
+
     # Now call the LLM with the final system prompt
     response = await use_model_func(
         query,
@@ -803,6 +814,7 @@ async def kg_query_with_keywords(
     )
     return response

+
 async def extract_keywords_only(
     text: str,
     param: QueryParam,
@@ -881,6 +893,7 @@ async def extract_keywords_only(
     )
     return hl_keywords, ll_keywords

+
 async def _build_query_context(
     query: list,
     knowledge_graph_inst: BaseGraphStorage,
diff --git a/test.py b/test.py
index 895f0b30..80bcaa6d 100644
--- a/test.py
+++ b/test.py
@@ -39,4 +39,4 @@ print(
 # Perform hybrid search
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
\ No newline at end of file
+)

From d91a330e9dcb50ee08e438cd6a2592bd1d145087 Mon Sep 17 00:00:00 2001
From: Samuel Chan
Date: Wed, 15 Jan 2025 12:02:55 +0800
Subject: [PATCH 05/12] Enrich README.md for postgres usage, make some changes
 to cater to python version<12

---
 README.md                                   |  1 +
 examples/copy_postgres_llm_cache_to_json.py | 66 +++++++++++++++++++++
 lightrag/kg/postgres_impl.py                | 15 ++++-
 3 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 examples/copy_postgres_llm_cache_to_json.py

diff --git a/README.md b/README.md
index e8401a3d..2178c3ab 100644
--- a/README.md
+++ b/README.md
@@ -360,6 +360,7 @@ see test_neo4j.py for a working example.
 ### Using PostgreSQL for Storage
 For production-level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
 * PostgreSQL is lightweight: the whole binary distribution, including all necessary plugins, can be zipped to 40MB (Ref to the [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0)); on Linux/Mac it is easy to install.
+* If you prefer Docker, start with this image if you are a beginner, to avoid hiccups (do read the overview first): https://hub.docker.com/r/shangor/postgres-for-rag
 * How to start?
Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py) * Create index for AGE example: (Change below `dickens` to your graph name if necessary) ``` diff --git a/examples/copy_postgres_llm_cache_to_json.py b/examples/copy_postgres_llm_cache_to_json.py new file mode 100644 index 00000000..f5fa0d51 --- /dev/null +++ b/examples/copy_postgres_llm_cache_to_json.py @@ -0,0 +1,66 @@ +import asyncio +import logging +import os +from dotenv import load_dotenv + +from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage +from lightrag.storage import JsonKVStorage + +load_dotenv() +ROOT_DIR = os.environ.get("ROOT_DIR") +WORKING_DIR = f"{ROOT_DIR}/dickens-pg" + +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +# AGE +os.environ["AGE_GRAPH_NAME"] = "chinese" + +postgres_db = PostgreSQLDB( + config={ + "host": "localhost", + "port": 15432, + "user": "rag", + "password": "rag", + "database": "r1", + } +) + + +async def main(): + await postgres_db.initdb() + + from_llm_response_cache = PGKVStorage( + namespace="llm_response_cache", + global_config={"embedding_batch_num": 6}, + embedding_func=None, + db=postgres_db, + ) + + to_llm_response_cache = JsonKVStorage( + namespace="llm_response_cache", + global_config={"working_dir": WORKING_DIR}, + embedding_func=None, + ) + + kv = {} + for c_id in await from_llm_response_cache.all_keys(): + print(f"Copying {c_id}") + workspace = c_id["workspace"] + mode = c_id["mode"] + _id = c_id["id"] + postgres_db.workspace = workspace + obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id) + if mode not in kv: + kv[mode] = {} + kv[mode][_id] = obj[_id] + print(f"Object {obj}") + await to_llm_response_cache.upsert(kv) + await to_llm_response_cache.index_done_callback() + print("Mission accomplished!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index b93a345b..86072c9f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -231,6 +231,16 @@ class PGKVStorage(BaseKVStorage): else: return None + async def all_keys(self) -> list[dict]: + if "llm_response_cache" == self.namespace: + sql = "select workspace,mode,id from lightrag_llm_cache" + res = await self.db.query(sql, multirows=True) + return res + else: + logger.error( + f"all_keys is only implemented for llm_response_cache, not for {self.namespace}" + ) + async def filter_keys(self, keys: List[str]) -> Set[str]: """Filter out duplicated content""" sql = SQL_TEMPLATES["filter_keys"].format( @@ -412,7 +422,10 @@ class PGDocStatusStorage(DocStatusStorage): async def filter_keys(self, data: list[str]) -> set[str]: """Return keys that don't exist in storage""" - sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})" + keys = ",".join([f"'{_id}'" for _id in data]) + sql = ( + f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})" + ) result = await self.db.query(sql, {"workspace": self.db.workspace}, True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. 
        if result is None:

From d1ba8c5db59e12e80bd948be024dff19ad2dc5bf Mon Sep 17 00:00:00 2001
From: Samuel Chan
Date: Thu, 16 Jan 2025 07:56:13 +0800
Subject: [PATCH 06/12] Add a script in examples to copy llm cache from one
 solution to another

---
 ...y => copy_llm_cache_to_another_storage.py} | 39 +++++++++++++++++--
 1 file changed, 35 insertions(+), 4 deletions(-)
 rename examples/{copy_postgres_llm_cache_to_json.py => copy_llm_cache_to_another_storage.py} (57%)

diff --git a/examples/copy_postgres_llm_cache_to_json.py b/examples/copy_llm_cache_to_another_storage.py
similarity index 57%
rename from examples/copy_postgres_llm_cache_to_json.py
rename to examples/copy_llm_cache_to_another_storage.py
index f5fa0d51..b9378c7c 100644
--- a/examples/copy_postgres_llm_cache_to_json.py
+++ b/examples/copy_llm_cache_to_another_storage.py
@@ -1,3 +1,9 @@
+"""
+Sometimes you need to switch storage solutions, but you want to save LLM tokens and time.
+This handy script helps you copy the LLM caches from one storage solution to another.
+(Not all storage implementations are supported.)
+"""
+
 import asyncio
 import logging
 import os
@@ -8,7 +14,7 @@ from lightrag.storage import JsonKVStorage

 load_dotenv()
 ROOT_DIR = os.environ.get("ROOT_DIR")
-WORKING_DIR = f"{ROOT_DIR}/dickens-pg"
+WORKING_DIR = f"{ROOT_DIR}/dickens"

 logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

@@ -24,12 +30,12 @@ postgres_db = PostgreSQLDB(
         "port": 15432,
         "user": "rag",
         "password": "rag",
-        "database": "r1",
+        "database": "r2",
     }
 )

-async def main():
+async def copy_from_postgres_to_json():
     await postgres_db.initdb()

     from_llm_response_cache = PGKVStorage(
@@ -62,5 +68,30 @@
     print("Mission accomplished!")

+async def copy_from_json_to_postgres():
+    await postgres_db.initdb()
+
+    from_llm_response_cache = JsonKVStorage(
+        namespace="llm_response_cache",
+        global_config={"working_dir": WORKING_DIR},
+        embedding_func=None,
+    )
+
+    to_llm_response_cache = PGKVStorage(
+        namespace="llm_response_cache",
+        global_config={"embedding_batch_num": 6},
+        embedding_func=None,
+        db=postgres_db,
+    )
+
+    for mode in await from_llm_response_cache.all_keys():
+        print(f"Copying {mode}")
+        caches = await from_llm_response_cache.get_by_id(mode)
+        for k, v in caches.items():
+            item = {mode: {k: v}}
+            print(f"\tCopying {item}")
+            await to_llm_response_cache.upsert(item)
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    asyncio.run(copy_from_json_to_postgres())

From d5ae6669eafafa2c3ee890729965f34d09e03eca Mon Sep 17 00:00:00 2001
From: jin <52519003+jin38324@users.noreply.github.com>
Date: Thu, 16 Jan 2025 12:52:37 +0800
Subject: [PATCH 07/12] support pipeline mode

---
 examples/lightrag_oracle_demo.py |  46 ++--
 lightrag/kg/oracle_impl.py       | 224 +++++------------
 lightrag/lightrag.py             | 392 +++++++++++++++++++++----------
 lightrag/operate.py              |  27 ++-
 lightrag/utils.py                |   8 +-
 5 files changed, 374 insertions(+), 323 deletions(-)

diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index 2cc8e018..8a5439e2 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -87,12 +87,14 @@ async def main():
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
     rag = LightRAG(
+        # log_level="DEBUG",
         working_dir=WORKING_DIR,
         entity_extract_max_gleaning = 1,
-        enable_llm_cache=False,
+        enable_llm_cache=True,
         enable_llm_cache_for_entity_extract = True,
         embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
True,"similarity_threshold": 0.90}, + enable_llm_cache=True, enable_llm_cache_for_entity_extract = True, + embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, + chunk_token_size=CHUNK_TOKEN_SIZE, llm_model_max_token_size = MAX_TOKENS, @@ -106,34 +108,30 @@ async def main(): graph_storage = "OracleGraphStorage", kv_storage = "OracleKVStorage", vector_storage="OracleVectorDBStorage", - doc_status_storage="OracleDocStatusStorage", - addon_params = {"example_number":1, "language":"Simplfied Chinese"}, + addon_params = {"example_number":1, + "language":"Simplfied Chinese", + "entity_types": ["organization", "person", "geo", "event"], + "insert_batch_size":2, + } ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.key_string_value_json_storage_cls.db = oracle_db - rag.vector_db_storage_cls.db = oracle_db - rag.graph_storage_cls.db = oracle_db - rag.doc_status_storage_cls.db = oracle_db - rag.doc_status.db = oracle_db - rag.full_docs.db = oracle_db - rag.text_chunks.db = oracle_db - rag.llm_response_cache.db = oracle_db - rag.key_string_value_json_storage_cls.db = oracle_db - rag.chunks_vdb.db = oracle_db - rag.relationships_vdb.db = oracle_db - rag.entities_vdb.db = oracle_db - rag.graph_storage_cls.db = oracle_db - rag.chunk_entity_relation_graph.db = oracle_db - rag.llm_response_cache.db = oracle_db + rag.set_storage_client(db_client = oracle_db) - rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func - # Extract and Insert into LightRAG storage - with open("./dickens/demo.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) + with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f: + all_text = f.read() + texts = [x for x in all_text.split("\n") if x] + + # New mode use pipeline + await rag.apipeline_process_documents(texts) + await rag.apipeline_process_chunks() + await rag.apipeline_process_extract_graph() + # Old method use ainsert + #await rag.ainsert(texts) + # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] for mode in modes: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index d464bcc4..c9deed4e 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -12,9 +12,6 @@ from ..base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, - DocStatusStorage, - DocStatus, - DocProcessingStatus, ) import oracledb @@ -156,8 +153,6 @@ class OracleDB: if data is None: await cursor.execute(sql) else: - # print(data) - # print(sql) await cursor.execute(sql, data) await connection.commit() except Exception as e: @@ -175,7 +170,7 @@ class OracleKVStorage(BaseKVStorage): def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num",10) ################ QUERY METHODS ################ @@ -204,12 +199,11 @@ class OracleKVStorage(BaseKVStorage): array_res = await self.db.query(SQL, params, multirows=True) res = {} for row in array_res: - res[row["id"]] = row + res[row["id"]] = row return res else: return None - - # Query by id + async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( @@ -228,8 +222,7 @@ class OracleKVStorage(BaseKVStorage): dict_res[mode] = {} for row in res: dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] - + res = [{k: v} 
for k, v in dict_res.items()] if res: data = res # [{"data":i} for i in res] # print(data) @@ -237,38 +230,42 @@ class OracleKVStorage(BaseKVStorage): else: return None + async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]: + """Specifically for llm_response_cache.""" + if ids is not None: + SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + else: + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + res = await self.db.query(SQL, params, multirows=True) + if res: + return res + else: + return None + async def filter_keys(self, keys: list[str]) -> set[str]: - """remove duplicated""" + """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) ) params = {"workspace": self.db.workspace} - try: - await self.db.query(SQL, params) - except Exception as e: - logger.error(f"Oracle database error: {e}") - print(SQL) - print(params) res = await self.db.query(SQL, params, multirows=True) - data = None if res: exist_keys = [key["id"] for key in res] data = set([s for s in keys if s not in exist_keys]) + return data else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data + return set(keys) + ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - # print(self._data) - # values = [] if self.namespace == "text_chunks": list_data = [ { - "__id__": k, + "id": k, **{k1: v1 for k1, v1 in v.items()}, } for k, v in data.items() @@ -284,33 +281,30 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - # print(list_data) + + merge_sql = SQL_TEMPLATES["merge_chunk"] for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"] - data = { - "check_id": item["__id__"], - "id": item["__id__"], + _data = { + "id": item["id"], "content": item["content"], "workspace": self.db.workspace, "tokens": item["tokens"], "chunk_order_index": item["chunk_order_index"], "full_doc_id": item["full_doc_id"], "content_vector": item["__vector__"], + "status": item["status"], } - # print(merge_sql) - await self.db.execute(merge_sql, data) - + await self.db.execute(merge_sql, _data) if self.namespace == "full_docs": - for k, v in self._data.items(): + for k, v in data.items(): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] - data = { + _data = { "id": k, "content": v["content"], "workspace": self.db.workspace, } - # print(merge_sql) - await self.db.execute(merge_sql, data) + await self.db.execute(merge_sql, _data) if self.namespace == "llm_response_cache": for mode, items in data.items(): @@ -325,102 +319,20 @@ class OracleKVStorage(BaseKVStorage): } await self.db.execute(upsert_sql, _data) - return left_data + return None + + async def change_status(self, id: str, status: str): + SQL = SQL_TEMPLATES["change_status"].format( + table_name=N_T[self.namespace] + ) + params = {"workspace": self.db.workspace, "id": id, "status": status} + await self.db.execute(SQL, params) async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: logger.info("full doc and chunk data had been saved into oracle db!") -@dataclass -class 
OracleDocStatusStorage(DocStatusStorage): - """Oracle implementation of document status storage""" - # should pass db object to self.db - db: OracleDB = None - meta_fields = None - - def __post_init__(self): - pass - - async def filter_keys(self, ids: list[str]) -> set[str]: - """Return keys that don't exist in storage""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( - ids = ",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - res = await self.db.query(SQL, params, True) - # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. - if res: - existed = set([element["id"] for element in res]) - return set(ids) - existed - else: - return set(ids) - - async def get_status_counts(self) -> Dict[str, int]: - """Get counts of documents in each status""" - SQL = SQL_TEMPLATES["get_status_counts"] - params = {"workspace": self.db.workspace} - res = await self.db.query(SQL, params, True) - # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] - counts = {} - for doc in res: - counts[doc["status"]] = doc["count"] - return counts - - async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]: - """Get all documents by status""" - SQL = SQL_TEMPLATES["get_docs_by_status"] - params = {"workspace": self.db.workspace, "status": status} - res = await self.db.query(SQL, params, True) - # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] - # Converting to be a dict - return { - element["id"]: DocProcessingStatus( - #content_summary=element["content_summary"], - content_summary = "", - content_length=element["CONTENT_LENGTH"], - status=element["STATUS"], - created_at=element["CREATETIME"], - updated_at=element["UPDATETIME"], - chunks_count=-1, - #chunks_count=element["chunks_count"], - ) - for element in res - } - - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def index_done_callback(self): - """Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here""" - logger.info("Doc status had been saved into ORACLE db!") - - async def upsert(self, data: dict[str, dict]): - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ - SQL = SQL_TEMPLATES["merge_doc_status"] - for k, v in data.items(): - # chunks_count is optional - params = { - "workspace": self.db.workspace, - "id": k, - "content_summary": v["content_summary"], - "content_length": v["content_length"], - "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, - "status": v["status"], - } - await self.db.execute(SQL, params) - return data - - @dataclass class OracleVectorDBStorage(BaseVectorStorage): # should pass db object to self.db @@ -466,7 +378,7 @@ class OracleGraphStorage(BaseGraphStorage): def __post_init__(self): """从graphml文件加载图""" - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ @@ -500,7 +412,6 @@ class OracleGraphStorage(BaseGraphStorage): "content": content, 
"content_vector": content_vector, } - # print(merge_sql) await self.db.execute(merge_sql, data) # self._graph.add_node(node_id, **node_data) @@ -718,9 +629,10 @@ TABLES = { }, "LIGHTRAG_DOC_CHUNKS": { "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id varchar(256) PRIMARY KEY, + id varchar(256), workspace varchar(1024), full_doc_id varchar(256), + status varchar(256), chunk_order_index NUMBER, tokens NUMBER, content CLOB, @@ -795,9 +707,9 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", @@ -808,24 +720,34 @@ SQL_TEMPLATES = { "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", - "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + + "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", + "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", + + "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", + + "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", + "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = :id and a.workspace = :workspace) WHEN NOT MATCHED THEN INSERT(id,content,workspace) values(:id,:content,:workspace)""", - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + ON (id = :id and workspace = :workspace) + WHEN NOT MATCHED THEN INSERT + (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) + values 
(:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a USING DUAL @@ -838,27 +760,7 @@ SQL_TEMPLATES = { return_value = :return_value, cache_mode = :cache_mode, updatetime = SYSDATE""", - - "get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})", - "get_status_counts": """SELECT status as "status", COUNT(1) as "count" - FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""", - - "get_docs_by_status": """select content_length,status, - TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime - from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""", - - "merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :id and a.workspace = :workspace) - WHEN NOT MATCHED THEN - INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status) - WHEN MATCHED THEN UPDATE - SET content_summary = :content_summary, - content_length = :content_length, - chunks_count = :chunks_count, - status = :status, - updatetime = SYSDATE""", # SQL for VectorStorage "entities": """SELECT name as entity_name FROM diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6e2f1c0e..7d8cdf45 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -26,6 +26,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, + statistic_data ) from .base import ( BaseGraphStorage, @@ -36,22 +37,31 @@ from .base import ( DocStatus, ) -from .storage import ( - JsonKVStorage, - NanoVectorDBStorage, - NetworkXStorage, - JsonDocStatusStorage, -) - from .prompt import GRAPH_FIELD_SEP +STORAGES = { + "JsonKVStorage": '.storage', + "NanoVectorDBStorage": '.storage', + "NetworkXStorage": '.storage', + "JsonDocStatusStorage": '.storage', -# future KG integrations - -# from .kg.ArangoDB_impl import ( -# GraphStorage as ArangoDBStorage -# ) - + "Neo4JStorage":".kg.neo4j_impl", + "OracleKVStorage":".kg.oracle_impl", + "OracleGraphStorage":".kg.oracle_impl", + "OracleVectorDBStorage":".kg.oracle_impl", + "MilvusVectorDBStorge":".kg.milvus_impl", + "MongoKVStorage":".kg.mongo_impl", + "ChromaVectorDBStorage":".kg.chroma_impl", + "TiDBKVStorage":".kg.tidb_impl", + "TiDBVectorDBStorage":".kg.tidb_impl", + "TiDBGraphStorage":".kg.tidb_impl", + "PGKVStorage":".kg.postgres_impl", + "PGVectorStorage":".kg.postgres_impl", + "AGEStorage":".kg.age_impl", + "PGGraphStorage":".kg.postgres_impl", + "GremlinStorage":".kg.gremlin_impl", + "PGDocStatusStorage":".kg.postgres_impl", +} def lazy_external_import(module_name: str, class_name: str): """Lazily import a class from an external module based on the package of the caller.""" @@ -65,36 +75,13 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib - - # Import the module using importlib module = importlib.import_module(module_name, package=package) - - # Get the class from the module and instantiate it cls = getattr(module, class_name) return cls(*args, **kwargs) return import_class -Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") -OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") -OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") -OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") 
-OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage") -MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") -MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") -ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") -TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") -TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") -TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage") -PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage") -PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage") -AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") -PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage") -GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage") -PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage") - - def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. @@ -198,52 +185,64 @@ class LightRAG: logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") - - _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) - logger.debug(f"LightRAG init with param:\n {_print_config}\n") - - # @TODO: should move all storage setup here to leverage initial start params attached to self. - - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( - self._get_storage_class()[self.kv_storage] - ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ - self.vector_storage - ] - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ - self.graph_storage - ] - if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), - embedding_func=None, - ) + # show config + global_config=asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) + logger.debug(f"LightRAG init with param:\n {_print_config}\n") + # Init LLM self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) + + + # Initialize all storages + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage) + + self.key_string_value_json_storage_cls = partial( + self.key_string_value_json_storage_cls, + global_config=global_config + ) + + self.vector_db_storage_cls = partial( + self.vector_db_storage_cls, + global_config=global_config + ) + + self.graph_storage_cls = partial( + self.graph_storage_cls, + global_config=global_config + ) + + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace="json_doc_status_storage", + embedding_func=None, + ) + + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, + ) #### # add embedding func by walter #### self.full_docs = self.key_string_value_json_storage_cls( namespace="full_docs", - 
global_config=asdict(self), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( namespace="text_chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( namespace="chunk_entity_relation", - global_config=asdict(self), embedding_func=self.embedding_func, ) #### @@ -252,73 +251,64 @@ class LightRAG: self.entities_vdb = self.vector_db_storage_cls( namespace="entities", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( namespace="relationships", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( namespace="chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) + if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"): + hashing_kv = self.llm_response_cache + else: + hashing_kv = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, + ) + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), - embedding_func=None, - ), + hashing_kv=hashing_kv, **self.llm_model_kwargs, ) ) # Initialize document status storage - self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage] + self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status = self.doc_status_storage_cls( namespace="doc_status", - global_config=asdict(self), + global_config=global_config, embedding_func=None, ) - def _get_storage_class(self) -> dict: - return { - # kv storage - "JsonKVStorage": JsonKVStorage, - "OracleKVStorage": OracleKVStorage, - "OracleDocStatusStorage":OracleDocStatusStorage, - "MongoKVStorage": MongoKVStorage, - "TiDBKVStorage": TiDBKVStorage, - # vector storage - "NanoVectorDBStorage": NanoVectorDBStorage, - "OracleVectorDBStorage": OracleVectorDBStorage, - "MilvusVectorDBStorge": MilvusVectorDBStorge, - "ChromaVectorDBStorage": ChromaVectorDBStorage, - "TiDBVectorDBStorage": TiDBVectorDBStorage, - # graph storage - "NetworkXStorage": NetworkXStorage, - "Neo4JStorage": Neo4JStorage, - "OracleGraphStorage": OracleGraphStorage, - "AGEStorage": AGEStorage, - "PGGraphStorage": PGGraphStorage, - "PGKVStorage": PGKVStorage, - "PGDocStatusStorage": PGDocStatusStorage, - "PGVectorStorage": PGVectorStorage, - "TiDBGraphStorage": TiDBGraphStorage, - "GremlinStorage": GremlinStorage, - # "ArangoDBStorage": ArangoDBStorage - "JsonDocStatusStorage": JsonDocStatusStorage, - } + def _get_storage_class(self, storage_name: str) -> dict: + import_path = STORAGES[storage_name] + storage_class = lazy_external_import(import_path, storage_name) + return storage_class + + def set_storage_client(self,db_client): + # Now only tested on Oracle Database + for storage in [self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache]: + # set 
client
             storage.db = db_client
 
     def insert(
         self, string_or_strings, split_by_character=None, split_by_character_only=False
@@ -358,6 +348,11 @@ class LightRAG:
             }
             for content in unique_contents
         }
+
+        # 3. Store original document and chunks
+        await self.full_docs.upsert(
+            {doc_id: {"content": doc["content"]}}
+        )
 
         # 3. Filter out already processed documents
         _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
@@ -406,12 +401,7 @@ class LightRAG:
             }
 
             # Update status with chunks information
-            doc_status.update(
-                {
-                    "chunks_count": len(chunks),
-                    "updated_at": datetime.now().isoformat(),
-                }
-            )
+            doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()})
             await self.doc_status.upsert({doc_id: doc_status})
 
             try:
@@ -435,30 +425,16 @@ class LightRAG:
 
                 self.chunk_entity_relation_graph = maybe_new_kg
 
-                # Store original document and chunks
-                await self.full_docs.upsert(
-                    {doc_id: {"content": doc["content"]}}
-                )
+
                 await self.text_chunks.upsert(chunks)
 
                 # Update status to processed
-                doc_status.update(
-                    {
-                        "status": DocStatus.PROCESSED,
-                        "updated_at": datetime.now().isoformat(),
-                    }
-                )
+                doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()})
                 await self.doc_status.upsert({doc_id: doc_status})
 
             except Exception as e:
                 # Mark as failed if any step fails
-                doc_status.update(
-                    {
-                        "status": DocStatus.FAILED,
-                        "error": str(e),
-                        "updated_at": datetime.now().isoformat(),
-                    }
-                )
+                doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()})
                 await self.doc_status.upsert({doc_id: doc_status})
                 raise e
 
@@ -540,6 +516,174 @@ class LightRAG:
         if update_storage:
             await self._insert_done()
 
+    async def apipeline_process_documents(self, string_or_strings):
+        """Deduplicate the input list, generate document IDs with an initial pending status, filter out already-stored documents, and store the new docs
+        Args:
+            string_or_strings: Single document string or list of document strings
+        """
+        if isinstance(string_or_strings, str):
+            string_or_strings = [string_or_strings]
+
+        # 1. Remove duplicate contents from the list
+        unique_contents = list(set(doc.strip() for doc in string_or_strings))
+
+        logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents")
+
+        # 2. Generate document IDs and initial status
+        new_docs = {
+            compute_mdhash_id(content, prefix="doc-"): {
+                "content": content,
+                "content_summary": self._get_content_summary(content),
+                "content_length": len(content),
+                "status": DocStatus.PENDING,
+                "created_at": datetime.now().isoformat(),
+                "updated_at": None,
+            }
+            for content in unique_contents
+        }
+
+        # 3. Filter out already processed documents
+        _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
+        if len(_not_stored_doc_keys) < len(new_docs):
+            logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents")
+            new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
+
+        if not new_docs:
+            logger.info(f"All documents have been processed or are duplicates")
+            return None
+
+        # 4. Store original document
+        for doc_id, doc in new_docs.items():
+            await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
+            await self.full_docs.change_status(doc_id, DocStatus.PENDING)
+        logger.info(f"Stored {len(new_docs)} new unique documents")
+
+    async def apipeline_process_chunks(self):
+        """Get pending documents, split them into chunks, and insert the chunks"""
+        # 1. 
get all pending and failed documents
+        _todo_doc_keys = []
+        _failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
+        _pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
+        if _failed_doc:
+            _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
+        if _pendding_doc:
+            _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
+        if not _todo_doc_keys:
+            logger.info("All documents have been processed or are duplicates")
+            return None
+        else:
+            logger.info(f"Found {len(_todo_doc_keys)} documents still to be processed")
+
+        new_docs = {
+            doc["id"]: doc
+            for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
+        }
+
+        # 2. split docs into chunks, insert chunks, update doc status
+        chunk_cnt = 0
+        batch_size = self.addon_params.get("insert_batch_size", 10)
+        for i in range(0, len(new_docs), batch_size):
+            batch_docs = dict(list(new_docs.items())[i : i + batch_size])
+            for doc_id, doc in tqdm_async(
+                batch_docs.items(), desc=f"Level 1 - Splitting doc in batch {i//batch_size + 1}"
+            ):
+                try:
+                    # Generate chunks from document
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_id,
+                            "status": DocStatus.PENDING,
+                        }
+                        for dp in chunking_by_token_size(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            tiktoken_model=self.tiktoken_model_name,
+                        )
+                    }
+                    chunk_cnt += len(chunks)
+                    await self.text_chunks.upsert(chunks)
+                    await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
+
+                    try:
+                        # Store chunks in vector database
+                        await self.chunks_vdb.upsert(chunks)
+                        # Update doc status
+                        await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
+                    except Exception as e:
+                        # Mark as failed if any step fails
+                        await self.full_docs.change_status(doc_id, DocStatus.FAILED)
+                        raise e
+                except Exception as e:
+                    import traceback
+                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    logger.error(error_msg)
+                    continue
+        logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
+
+    async def apipeline_process_extract_graph(self):
+        """Get pending or failed chunks and extract entities and relationships from each chunk"""
+        # 1. 
get all pending and failed chunks + _todo_chunk_keys = [] + _failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + if _failed_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) + if _pendding_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks]) + if not _todo_chunk_keys: + logger.info("All chunks have been processed or are duplicates") + return None + + # Process documents in batches + batch_size = self.addon_params.get("insert_batch_size", 10) + + semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously + + async def process_chunk(chunk_id): + async with semaphore: + chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])} + # Extract and store entities and relationships + try: + maybe_new_kg = await extract_entities( + chunks, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + llm_response_cache=self.llm_response_cache, + global_config=asdict(self), + ) + if maybe_new_kg is None: + logger.info("No entities or relationships extracted!") + # Update status to processed + await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED) + except Exception as e: + logger.error("Failed to extract entities and relationships") + # Mark as failed if any step fails + await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) + raise e + + with tqdm_async(total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0) as progress: + tasks = [] + for chunk_id in _todo_chunk_keys: + task = asyncio.create_task(process_chunk(chunk_id)) + tasks.append(task) + + for future in asyncio.as_completed(tasks): + await future + progress.update(1) + progress.set_postfix({ + 'LLM call': statistic_data["llm_call"], + 'LLM cache': statistic_data["llm_cache"], + }) + + # Ensure all indexes are updated after each document + await self._insert_done() + async def _insert_done(self): tasks = [] for storage_inst in [ diff --git a/lightrag/operate.py b/lightrag/operate.py index 97ac245c..f9e48dbf 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,6 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, + statistic_data ) from .base import ( BaseGraphStorage, @@ -371,14 +372,16 @@ async def extract_entities( if need_to_restore: llm_response_cache.global_config = global_config if cached_return: + logger.debug(f"Found cache for {arg_hash}") + statistic_data["llm_cache"] += 1 return cached_return - + statistic_data["llm_call"] += 1 if history_messages: res: str = await use_llm_func( input_text, history_messages=history_messages ) else: - res: str = await use_llm_func(input_text) + res: str = await use_llm_func(input_text) await save_to_cache( llm_response_cache, CacheData(args_hash=arg_hash, content=res, prompt=_prompt), @@ -459,10 +462,8 @@ async def extract_entities( now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) ] - print( + logger.debug( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", - end="", - flush=True, ) return dict(maybe_nodes), dict(maybe_edges) @@ -470,8 +471,8 @@ async def extract_entities( for result in tqdm_async( asyncio.as_completed([_process_single_content(c) 
for c in ordered_chunks]), total=len(ordered_chunks), - desc="Extracting entities from chunks", - unit="chunk", + desc="Level 2 - Extracting entities and relationships", + unit="chunk", position=1,leave=False ): results.append(await result) @@ -482,7 +483,7 @@ async def extract_entities( maybe_nodes[k].extend(v) for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) - logger.info("Inserting entities into storage...") + logger.debug("Inserting entities into storage...") all_entities_data = [] for result in tqdm_async( asyncio.as_completed( @@ -492,12 +493,12 @@ async def extract_entities( ] ), total=len(maybe_nodes), - desc="Inserting entities", - unit="entity", + desc="Level 3 - Inserting entities", + unit="entity", position=2,leave=False ): all_entities_data.append(await result) - logger.info("Inserting relationships into storage...") + logger.debug("Inserting relationships into storage...") all_relationships_data = [] for result in tqdm_async( asyncio.as_completed( @@ -509,8 +510,8 @@ async def extract_entities( ] ), total=len(maybe_edges), - desc="Inserting relationships", - unit="relationship", + desc="Level 3 - Inserting relationships", + unit="relationship", position=3,leave=False ): all_relationships_data.append(await result) diff --git a/lightrag/utils.py b/lightrag/utils.py index 56e4191c..a83c0382 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -30,8 +30,13 @@ class UnlimitedSemaphore: ENCODER = None +statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} + logger = logging.getLogger("lightrag") +# Set httpx logging level to WARNING +logging.getLogger("httpx").setLevel(logging.WARNING) + def set_logger(log_file: str): logger.setLevel(logging.DEBUG) @@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return None, None, None, None # For naive mode, only use simple cache matching - if mode == "naive": + #if mode == "naive": + if mode == "default": if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} else: From 6ae8647285abcdfb6875f667c5c5daa1c7409fa8 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:58:15 +0800 Subject: [PATCH 08/12] support pipeline mode --- .gitignore | 2 +- examples/lightrag_oracle_demo.py | 44 +++--- lightrag/kg/oracle_impl.py | 52 +++---- lightrag/lightrag.py | 245 ++++++++++++++++++------------- lightrag/operate.py | 24 ++- lightrag/utils.py | 8 +- 6 files changed, 203 insertions(+), 172 deletions(-) diff --git a/.gitignore b/.gitignore index 0e0ec299..ec95f8a5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ -.DS_Store \ No newline at end of file +.DS_Store diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 8a5439e2..6de6e0a7 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -89,49 +89,45 @@ async def main(): rag = LightRAG( # log_level="DEBUG", working_dir=WORKING_DIR, - entity_extract_max_gleaning = 1, - + entity_extract_max_gleaning=1, enable_llm_cache=True, - enable_llm_cache_for_entity_extract = True, - embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, - - + enable_llm_cache_for_entity_extract=True, + embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90}, chunk_token_size=CHUNK_TOKEN_SIZE, - llm_model_max_token_size = MAX_TOKENS, + llm_model_max_token_size=MAX_TOKENS, 
llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=embedding_dimension,
             max_token_size=500,
             func=embedding_func,
-        ),
-
-        graph_storage = "OracleGraphStorage",
-        kv_storage = "OracleKVStorage",
+        ),
+        graph_storage="OracleGraphStorage",
+        kv_storage="OracleKVStorage",
         vector_storage="OracleVectorDBStorage",
-
-        addon_params = {"example_number":1,
-                        "language":"Simplified Chinese",
-                        "entity_types": ["organization", "person", "geo", "event"],
-                        "insert_batch_size":2,
-                        }
+        addon_params={
+            "example_number": 1,
+            "language": "Simplified Chinese",
+            "entity_types": ["organization", "person", "geo", "event"],
+            "insert_batch_size": 2,
+        },
     )
 
     # Set the KV/vector/graph storages' `db` property, so all operations use the same connection pool
-    rag.set_storage_client(db_client = oracle_db)
+    rag.set_storage_client(db_client=oracle_db)
 
     # Extract and Insert into LightRAG storage
-    with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f:
+    with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
         all_text = f.read()
     texts = [x for x in all_text.split("\n") if x]
-    
+
     # New mode: use the pipeline
     await rag.apipeline_process_documents(texts)
-    await rag.apipeline_process_chunks()  
+    await rag.apipeline_process_chunks()
    await rag.apipeline_process_extract_graph()
 
     # Old method: use ainsert
-    #await rag.ainsert(texts)
-    
+    # await rag.ainsert(texts)
+
     # Perform search in different modes
     modes = ["naive", "local", "global", "hybrid"]
     for mode in modes:
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index c9deed4e..e30b6909 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -3,7 +3,7 @@ import asyncio
 # import html
 # import os
 from dataclasses import dataclass
-from typing import Union, List, Dict, Set, Any, Tuple
+from typing import Union
 
 import numpy as np
 import array
@@ -170,7 +170,7 @@ class OracleKVStorage(BaseKVStorage):
 
     def __post_init__(self):
         self._data = {}
-        self._max_batch_size = self.global_config.get("embedding_batch_num",10)
+        self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
 
     ################ QUERY METHODS ################
 
@@ -190,7 +190,7 @@ class OracleKVStorage(BaseKVStorage):
             return res
         else:
             return None
-    
+
     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
@@ -199,11 +199,11 @@ class OracleKVStorage(BaseKVStorage):
             array_res = await self.db.query(SQL, params, multirows=True)
             res = {}
             for row in array_res:
-                res[row["id"]] = row
+                res[row["id"]] = row
             return res
         else:
             return None
-    
+
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """get doc_chunks data based on id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
@@ -222,7 +222,7 @@ class OracleKVStorage(BaseKVStorage):
                 dict_res[mode] = {}
             for row in res:
                 dict_res[row["mode"]][row["id"]] = row
-            res = [{k: v} for k, v in dict_res.items()] 
+            res = [{k: v} for k, v in dict_res.items()]
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -230,7 +230,9 @@ class OracleKVStorage(BaseKVStorage):
         else:
             return None
 
-    async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]:
+    async def get_by_status_and_ids(
+        self, status: str, ids: list[str]
+    ) -> Union[list[dict], None]:
         """Get rows matching a status, optionally restricted to the given ids (used for full_docs and text_chunks)."""
         if ids is not 
None:
             SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
@@ -244,7 +246,7 @@ class OracleKVStorage(BaseKVStorage):
             return res
         else:
             return None
-    
+
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
@@ -258,7 +260,6 @@ class OracleKVStorage(BaseKVStorage):
             return data
         else:
             return set(keys)
-        
     ################ INSERT METHODS ################
 
     async def upsert(self, data: dict[str, dict]):
@@ -281,7 +282,7 @@ class OracleKVStorage(BaseKVStorage):
             embeddings = np.concatenate(embeddings_list)
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-        
+
             merge_sql = SQL_TEMPLATES["merge_chunk"]
             for item in list_data:
                 _data = {
@@ -320,11 +321,9 @@ class OracleKVStorage(BaseKVStorage):
                 await self.db.execute(upsert_sql, _data)
 
         return None
-    
+
     async def change_status(self, id: str, status: str):
-        SQL = SQL_TEMPLATES["change_status"].format(
-            table_name=N_T[self.namespace]
-        )
+        SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
         params = {"workspace": self.db.workspace, "id": id, "status": status}
         await self.db.execute(SQL, params)
 
@@ -673,8 +672,8 @@ TABLES = {
     },
     "LIGHTRAG_LLM_CACHE": {
         "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
-	                id varchar(256) PRIMARY KEY,
-	                workspace varchar(1024),
+                    id varchar(256) PRIMARY KEY,
+                    workspace varchar(1024),
                     cache_mode varchar(256),
                     model_name varchar(256),
                     original_prompt clob,
@@ -708,47 +707,32 @@ TABLES = {
 SQL_TEMPLATES = {
     # SQL for KVStorage
     "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
-
     "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
-
     "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
 FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
-
     "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
 FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
-
     "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
 FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
-
     "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
-
     "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
-
     "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
-
     "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status and ID in ({ids})",
-
     "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
-
     "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
-
     "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
-
     "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
-
     "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
     USING DUAL
     ON (a.id = :id and a.workspace = :workspace)
     WHEN NOT MATCHED THEN
     INSERT(id,content,workspace) values(:id,:content,:workspace)""",
-
     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
     USING DUAL
     ON (id = :id and workspace = :workspace)
     WHEN NOT MATCHED THEN INSERT
     (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
     values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
-
     "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
     USING DUAL
     ON (a.id = :id)
     WHEN 
NOT MATCHED THEN INSERT(id,content,workspace) values(:id,:content,:workspace)""", - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS USING DUAL ON (id = :id and workspace = :workspace) WHEN NOT MATCHED THEN INSERT (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, - "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a USING DUAL ON (a.id = :id) @@ -760,8 +744,6 @@ SQL_TEMPLATES = { return_value = :return_value, cache_mode = :cache_mode, updatetime = SYSDATE""", - - # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance @@ -818,7 +800,7 @@ SQL_TEMPLATES = { INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) WHEN MATCHED THEN - UPDATE SET + UPDATE SET entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL @@ -827,7 +809,7 @@ SQL_TEMPLATES = { INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) WHEN MATCHED THEN - UPDATE SET + UPDATE SET weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7d8cdf45..0902fc50 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -26,7 +26,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, - statistic_data + statistic_data, ) from .base import ( BaseGraphStorage, @@ -39,30 +39,30 @@ from .base import ( from .prompt import GRAPH_FIELD_SEP -STORAGES = { - "JsonKVStorage": '.storage', - "NanoVectorDBStorage": '.storage', - "NetworkXStorage": '.storage', - "JsonDocStatusStorage": '.storage', - - "Neo4JStorage":".kg.neo4j_impl", - "OracleKVStorage":".kg.oracle_impl", - "OracleGraphStorage":".kg.oracle_impl", - "OracleVectorDBStorage":".kg.oracle_impl", - "MilvusVectorDBStorge":".kg.milvus_impl", - "MongoKVStorage":".kg.mongo_impl", - "ChromaVectorDBStorage":".kg.chroma_impl", - "TiDBKVStorage":".kg.tidb_impl", - "TiDBVectorDBStorage":".kg.tidb_impl", - "TiDBGraphStorage":".kg.tidb_impl", - "PGKVStorage":".kg.postgres_impl", - "PGVectorStorage":".kg.postgres_impl", - "AGEStorage":".kg.age_impl", - "PGGraphStorage":".kg.postgres_impl", - "GremlinStorage":".kg.gremlin_impl", - "PGDocStatusStorage":".kg.postgres_impl", +STORAGES = { + "JsonKVStorage": ".storage", + "NanoVectorDBStorage": ".storage", + "NetworkXStorage": ".storage", + "JsonDocStatusStorage": ".storage", + "Neo4JStorage": ".kg.neo4j_impl", + "OracleKVStorage": ".kg.oracle_impl", + "OracleGraphStorage": ".kg.oracle_impl", + "OracleVectorDBStorage": ".kg.oracle_impl", + "MilvusVectorDBStorge": ".kg.milvus_impl", + "MongoKVStorage": ".kg.mongo_impl", + "ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + 
"TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", } + def lazy_external_import(module_name: str, class_name: str): """Lazily import a class from an external module based on the package of the caller.""" @@ -75,6 +75,7 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib + module = importlib.import_module(module_name, package=package) cls = getattr(module, class_name) return cls(*args, **kwargs) @@ -190,7 +191,7 @@ class LightRAG: os.makedirs(self.working_dir) # show config - global_config=asdict(self) + global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -198,31 +199,33 @@ class LightRAG: self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) - # Initialize all storages - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage) - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage) - + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class(self.kv_storage) + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + self.vector_storage + ) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + self.graph_storage + ) + self.key_string_value_json_storage_cls = partial( - self.key_string_value_json_storage_cls, - global_config=global_config + self.key_string_value_json_storage_cls, global_config=global_config ) self.vector_db_storage_cls = partial( - self.vector_db_storage_cls, - global_config=global_config + self.vector_db_storage_cls, global_config=global_config ) self.graph_storage_cls = partial( - self.graph_storage_cls, - global_config=global_config + self.graph_storage_cls, global_config=global_config ) self.json_doc_status_storage = self.key_string_value_json_storage_cls( namespace="json_doc_status_storage", - embedding_func=None, + embedding_func=None, ) self.llm_response_cache = self.key_string_value_json_storage_cls( @@ -264,13 +267,15 @@ class LightRAG: embedding_func=self.embedding_func, ) - if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"): + if self.llm_response_cache and hasattr( + self.llm_response_cache, "global_config" + ): hashing_kv = self.llm_response_cache else: hashing_kv = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - embedding_func=None, - ) + namespace="llm_response_cache", + embedding_func=None, + ) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -292,21 +297,24 @@ class LightRAG: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name) return storage_class - - def set_storage_client(self,db_client): + + def set_storage_client(self, db_client): # Now only tested on Oracle Database - for storage in [self.vector_db_storage_cls, - self.graph_storage_cls, - self.doc_status, self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.key_string_value_json_storage_cls, - self.chunks_vdb, - self.relationships_vdb, - 
self.entities_vdb, - self.graph_storage_cls, - self.chunk_entity_relation_graph, - self.llm_response_cache]: + for storage in [ + self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache, + ]: # set client storage.db = db_client @@ -348,11 +356,6 @@ class LightRAG: } for content in unique_contents } - - # 3. Store original document and chunks - await self.full_docs.upsert( - {doc_id: {"content": doc["content"]}} - ) # 3. Filter out already processed documents _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) @@ -401,7 +404,12 @@ class LightRAG: } # Update status with chunks information - doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "chunks_count": len(chunks), + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) try: @@ -425,16 +433,30 @@ class LightRAG: self.chunk_entity_relation_graph = maybe_new_kg - + # Store original document and chunks + await self.full_docs.upsert( + {doc_id: {"content": doc["content"]}} + ) await self.text_chunks.upsert(chunks) # Update status to processed - doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "status": DocStatus.PROCESSED, + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) except Exception as e: # Mark as failed if any step fails - doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "status": DocStatus.FAILED, + "error": str(e), + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) raise e @@ -527,7 +549,9 @@ class LightRAG: # 1. Remove duplicate contents from the list unique_contents = list(set(doc.strip() for doc in string_or_strings)) - logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents") + logger.info( + f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents" + ) # 2. Generate document IDs and initial status new_docs = { @@ -542,28 +566,34 @@ class LightRAG: for content in unique_contents } - # 3. Filter out already processed documents + # 3. Filter out already processed documents _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) if len(_not_stored_doc_keys) < len(new_docs): - logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents") + logger.info( + f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents" + ) new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} if not new_docs: - logger.info(f"All documents have been processed or are duplicates") + logger.info("All documents have been processed or are duplicates") return None - # 4. Store original document + # 4. 
Store original document
         for doc_id, doc in new_docs.items():
             await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
             await self.full_docs.change_status(doc_id, DocStatus.PENDING)
         logger.info(f"Stored {len(new_docs)} new unique documents")
-    
+
     async def apipeline_process_chunks(self):
-        """Get pending documents, split them into chunks, and insert the chunks"""        
-        # 1. get all pending and failed documents
+        """Get pending documents, split them into chunks, and insert the chunks"""
+        # 1. get all pending and failed documents
         _todo_doc_keys = []
-        _failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
-        _pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
+        _failed_doc = await self.full_docs.get_by_status_and_ids(
+            status=DocStatus.FAILED, ids=None
+        )
+        _pendding_doc = await self.full_docs.get_by_status_and_ids(
+            status=DocStatus.PENDING, ids=None
+        )
         if _failed_doc:
             _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
         if _pendding_doc:
@@ -573,10 +603,9 @@ class LightRAG:
             return None
         else:
             logger.info(f"Found {len(_todo_doc_keys)} documents still to be processed")
-        
+
         new_docs = {
-            doc["id"]: doc
-            for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
+            doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
         }
 
         # 2. split docs into chunks, insert chunks, update doc status
@@ -585,8 +614,9 @@ class LightRAG:
         for i in range(0, len(new_docs), batch_size):
             batch_docs = dict(list(new_docs.items())[i : i + batch_size])
             for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Level 1 - Splitting doc in batch {i//batch_size + 1}"
-            ):
+                batch_docs.items(),
+                desc=f"Level 1 - Splitting doc in batch {i//batch_size + 1}",
+            ):
                 try:
                     # Generate chunks from document
                     chunks = {
@@ -616,18 +646,23 @@ class LightRAG:
                         await self.full_docs.change_status(doc_id, DocStatus.FAILED)
                         raise e
                 except Exception as e:
-                    import traceback
-                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
-                    logger.error(error_msg)
-                    continue
-        logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
+                    import traceback
+
+                    error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    logger.error(error_msg)
+                    continue
+        logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
 
     async def apipeline_process_extract_graph(self):
         """Get pending or failed chunks and extract entities and relationships from each chunk"""
         # 1. 
get all pending and failed chunks _todo_chunk_keys = [] - _failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) - _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + _failed_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) if _failed_chunks: _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) if _pendding_chunks: @@ -635,15 +670,19 @@ class LightRAG: if not _todo_chunk_keys: logger.info("All chunks have been processed or are duplicates") return None - + # Process documents in batches batch_size = self.addon_params.get("insert_batch_size", 10) - semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously + semaphore = asyncio.Semaphore( + batch_size + ) # Control the number of tasks that are processed simultaneously - async def process_chunk(chunk_id): + async def process_chunk(chunk_id): async with semaphore: - chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])} + chunks = { + i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id]) + } # Extract and store entities and relationships try: maybe_new_kg = await extract_entities( @@ -662,25 +701,29 @@ class LightRAG: logger.error("Failed to extract entities and relationships") # Mark as failed if any step fails await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) - raise e + raise e - with tqdm_async(total=len(_todo_chunk_keys), - desc="\nLevel 1 - Processing chunks", - unit="chunk", - position=0) as progress: + with tqdm_async( + total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0, + ) as progress: tasks = [] for chunk_id in _todo_chunk_keys: task = asyncio.create_task(process_chunk(chunk_id)) tasks.append(task) - + for future in asyncio.as_completed(tasks): await future progress.update(1) - progress.set_postfix({ - 'LLM call': statistic_data["llm_call"], - 'LLM cache': statistic_data["llm_cache"], - }) - + progress.set_postfix( + { + "LLM call": statistic_data["llm_call"], + "LLM cache": statistic_data["llm_cache"], + } + ) + # Ensure all indexes are updated after each document await self._insert_done() diff --git a/lightrag/operate.py b/lightrag/operate.py index f9e48dbf..fbfd1fbd 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,7 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, - statistic_data + statistic_data, ) from .base import ( BaseGraphStorage, @@ -105,7 +105,9 @@ async def _handle_entity_relation_summary( llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] - language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) if len(tokens) < summary_max_tokens: # No need for summary @@ -360,7 +362,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages,ensure_ascii=False) + history = json.dumps(history_messages, ensure_ascii=False) _prompt = history + "\n" + input_text else: 
_prompt = input_text
 
@@ -381,7 +383,7 @@ async def extract_entities(
                 input_text, history_messages=history_messages
             )
         else:
-            res: str = await use_llm_func(input_text)            
+            res: str = await use_llm_func(input_text)
         await save_to_cache(
             llm_response_cache,
             CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
@@ -394,7 +396,7 @@ async def extract_entities(
         return await use_llm_func(input_text)
 
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
-        """"Prpocess a single chunk
+        """Process a single chunk
         Args:
             chunk_key_dp (tuple[str, TextChunkSchema]): ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
@@ -472,7 +474,9 @@ async def extract_entities(
         asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
         total=len(ordered_chunks),
         desc="Level 2 - Extracting entities and relationships",
-        unit="chunk", position=1,leave=False
+        unit="chunk",
+        position=1,
+        leave=False,
     ):
         results.append(await result)
 
@@ -494,7 +498,9 @@ async def extract_entities(
         ),
         total=len(maybe_nodes),
         desc="Level 3 - Inserting entities",
-        unit="entity", position=2,leave=False
+        unit="entity",
+        position=2,
+        leave=False,
     ):
         all_entities_data.append(await result)
 
@@ -511,7 +517,9 @@ async def extract_entities(
         ),
         total=len(maybe_edges),
         desc="Level 3 - Inserting relationships",
-        unit="relationship", position=3,leave=False
+        unit="relationship",
+        position=3,
+        leave=False,
     ):
         all_relationships_data.append(await result)
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index a83c0382..ce556ab2 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -41,7 +41,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
 
 def set_logger(log_file: str):
     logger.setLevel(logging.DEBUG)
 
-    file_handler = logging.FileHandler(log_file, encoding='utf-8')
+    file_handler = logging.FileHandler(log_file, encoding="utf-8")
     file_handler.setLevel(logging.DEBUG)
 
     formatter = logging.Formatter(
@@ -458,7 +458,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
         return None, None, None, None
 
     # For naive mode, only use simple cache matching
-    #if mode == "naive":
+    # if mode == "naive":
     if mode == "default":
         if exists_func(hashing_kv, "get_by_mode_and_id"):
             mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -479,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
         quantized = min_val = max_val = None
         if is_embedding_cache_enabled:
             # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"]
+            embedding_model_func = hashing_kv.global_config[
+                "embedding_func"
+            ].func  # ["func"]
             llm_model_func = hashing_kv.global_config.get("llm_model_func")
 
             current_embedding = await embedding_model_func([prompt])
 
From dd105d47fa031e2ed9ec3403b578ba3d154d271a Mon Sep 17 00:00:00 2001
From: Gurjot Singh
Date: Thu, 16 Jan 2025 11:15:21 +0530
Subject: [PATCH 09/12] Update README.md to include a detailed explanation of
 the new query_with_separate_keyword_extraction function.

---
 README.md | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/README.md b/README.md
index da1a1d56..233fe56d 100644
--- a/README.md
+++ b/README.md
@@ -330,6 +330,26 @@ rag = LightRAG(
 with open("./newText.txt") as f:
     rag.insert(f.read())
 ```
+### Separate Keyword Extraction
+We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. 
This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords. + +##### How It Works? +The function operates by dividing the input into two parts: +- `User Query` +- `Prompt` + +It then performs keyword extraction exclusively on the `user query`. This separation ensures that the extraction process is focused and relevant, unaffected by any additional language in the `prompt`. It also allows the `prompt` to serve purely for response formatting, maintaining the intent and clarity of the user's original question. + +##### Usage Example +This `example` shows how to tailor the function for educational content, focusing on detailed explanations for older students. + +```python +rag.query_with_separate_keyword_extraction( + query="Explain the law of gravity", + prompt="Provide a detailed explanation suitable for high school students studying physics.", + param=QueryParam(mode="hybrid") +) +``` ### Using Neo4J for Storage From e64805c9e2e08c8b27355b9429c4537fbe58a4d6 Mon Sep 17 00:00:00 2001 From: Gurjot Singh Date: Thu, 16 Jan 2025 11:26:19 +0530 Subject: [PATCH 10/12] Add example usage for separate keyword extraction of user's query --- examples/query_keyword_separation_example.py | 114 +++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 examples/query_keyword_separation_example.py diff --git a/examples/query_keyword_separation_example.py b/examples/query_keyword_separation_example.py new file mode 100644 index 00000000..43606834 --- /dev/null +++ b/examples/query_keyword_separation_example.py @@ -0,0 +1,114 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.utils import EmbeddingFunc +import numpy as np +from dotenv import load_dotenv +import logging +from openai import AzureOpenAI + +logging.basicConfig(level=logging.INFO) + +load_dotenv() + +AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") +AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") +AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") +AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") + +AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") +AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") + +WORKING_DIR = "./dickens" + +if os.path.exists(WORKING_DIR): + import shutil + + shutil.rmtree(WORKING_DIR) + +os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_OPENAI_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT, + ) + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if history_messages: + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + chat_completion = client.chat.completions.create( + model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name". 
+        messages=messages,
+        temperature=kwargs.get("temperature", 0),
+        top_p=kwargs.get("top_p", 1),
+        n=kwargs.get("n", 1),
+    )
+    return chat_completion.choices[0].message.content
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_EMBEDDING_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT,
+    )
+    embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
+
+    embeddings = [item.embedding for item in embedding.data]
+    return np.array(embeddings)
+
+
+async def test_funcs():
+    result = await llm_model_func("How are you?")
+    print("llm_model_func response: ", result)
+
+    result = await embedding_func(["How are you?"])
+    print("embedding_func result: ", result.shape)
+    print("Embedding dimension: ", result.shape[1])
+
+
+asyncio.run(test_funcs())
+
+embedding_dimension = 3072
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=embedding_dimension,
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+)
+
+book1 = open("./book_1.txt", encoding="utf-8")
+book2 = open("./book_2.txt", encoding="utf-8")
+
+rag.insert([book1.read(), book2.read()])
+
+# Example function demonstrating the new query_with_separate_keyword_extraction usage
+async def run_example():
+    query = "What are the top themes in this story?"
+    prompt = "Please simplify the response for a young audience."
+
+    # Using the new method to ensure the keyword extraction is only applied to the query
+    response = rag.query_with_separate_keyword_extraction(
+        query=query,
+        prompt=prompt,
+        param=QueryParam(mode="hybrid") # Adjust QueryParam mode as necessary
+    )
+
+    print("Extracted Response:", response)
+
+# Run the example asynchronously
+if __name__ == "__main__":
+    asyncio.run(run_example())

From 2ea104d738d64e39525099a191410f2f86f695cd Mon Sep 17 00:00:00 2001
From: Gurjot Singh
Date: Thu, 16 Jan 2025 11:31:22 +0530
Subject: [PATCH 11/12] Fix linting errors

---
 README.md                                    | 2 +-
 examples/query_keyword_separation_example.py | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 233fe56d..8a0da666 100644
--- a/README.md
+++ b/README.md
@@ -334,7 +334,7 @@ with open("./newText.txt") as f:
 We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords.
 
 ##### How It Works
-The function operates by dividing the input into two parts: 
+The function operates by dividing the input into two parts:
 - `User Query`
 - `Prompt`
 
diff --git a/examples/query_keyword_separation_example.py b/examples/query_keyword_separation_example.py
index 43606834..f11ce8c1 100644
--- a/examples/query_keyword_separation_example.py
+++ b/examples/query_keyword_separation_example.py
@@ -95,20 +95,22 @@ book2 = open("./book_2.txt", encoding="utf-8")
 
 rag.insert([book1.read(), book2.read()])
 
+
 # Example function demonstrating the new query_with_separate_keyword_extraction usage
 async def run_example():
     query = "What are the top themes in this story?"
     prompt = "Please simplify the response for a young audience."
- + # Using the new method to ensure the keyword extraction is only applied to the query response = rag.query_with_separate_keyword_extraction( query=query, prompt=prompt, - param=QueryParam(mode="hybrid") # Adjust QueryParam mode as necessary + param=QueryParam(mode="hybrid"), # Adjust QueryParam mode as necessary ) - + print("Extracted Response:", response) + # Run the example asynchronously if __name__ == "__main__": asyncio.run(run_example()) From d7cfe029ebcd645a180ae8e488fe66b080cf0550 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:24:29 +0800 Subject: [PATCH 12/12] Update __init__.py --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 7a26a282..b6317d84 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.1" +__version__ = "1.1.2" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG"