From 957bcf865998e6f806eb02c45cfc5ec243ebeb75 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Tue, 7 Jan 2025 13:51:20 +0800 Subject: [PATCH 1/4] Organize files move some test files from root to example --- .DS_Store | Bin 8196 -> 0 bytes .gitignore | 1 + .../get_all_edges_nx.py | 0 test.py => examples/test.py | 0 test_chromadb.py => examples/test_chromadb.py | 0 test_neo4j.py => examples/test_neo4j.py | 0 6 files changed, 1 insertion(+) delete mode 100644 .DS_Store rename get_all_edges_nx.py => examples/get_all_edges_nx.py (100%) rename test.py => examples/test.py (100%) rename test_chromadb.py => examples/test_chromadb.py (100%) rename test_neo4j.py => examples/test_neo4j.py (100%) diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 7489d923a9e7375a0aadb3d45e73e969d67f761d..0000000000000000000000000000000000000000
GIT binary patch [8196-byte binary delta for the deleted .DS_Store omitted]
From 30e98b19b10e0e4ab7b43e703dcf7583cdb581b1 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:36:28 +0800 Subject: [PATCH 2/4] update Oracle support add cache support, fix bug --- examples/lightrag_oracle_demo.py | 45 ++++-- lightrag/kg/oracle_impl.py | 265 +++++++++++++++++++++++++++---- lightrag/lightrag.py | 2 + lightrag/operate.py | 16 +- lightrag/utils.py | 4 +- 5 files changed, 284 insertions(+), 48 deletions(-) diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index bbb69319..2cc8e018 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" APIKEY = "ocigenerativeai" CHATMODEL = "cohere.command-r-plus" EMBEDMODEL = "cohere.embed-multilingual-v3.0" - +CHUNK_TOKEN_SIZE = 1024 +MAX_TOKENS = 4000 if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -86,27 +87,49 @@ async def main(): # We use Oracle DB as the KV/vector/graph storage # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( + working_dir=WORKING_DIR, + entity_extract_max_gleaning = 1, + enable_llm_cache=False, - working_dir=WORKING_DIR, - chunk_token_size=512, + embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, + enable_llm_cache_for_entity_extract = True, + + chunk_token_size=CHUNK_TOKEN_SIZE, + llm_model_max_token_size = MAX_TOKENS, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension, - max_token_size=512, + max_token_size=500, func=embedding_func, - ), - graph_storage="OracleGraphStorage", - kv_storage="OracleKVStorage", + ), + + graph_storage = "OracleGraphStorage", + kv_storage = "OracleKVStorage", vector_storage="OracleVectorDBStorage", + doc_status_storage="OracleDocStatusStorage", + + addon_params = {"example_number":1, "language":"Simplfied Chinese"}, ) - # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.graph_storage_cls.db = oracle_db + # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool rag.key_string_value_json_storage_cls.db = oracle_db rag.vector_db_storage_cls.db = oracle_db - # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c - rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func + rag.graph_storage_cls.db = oracle_db + rag.doc_status_storage_cls.db = oracle_db + rag.doc_status.db = oracle_db + rag.full_docs.db = oracle_db + rag.text_chunks.db = oracle_db + rag.llm_response_cache.db = oracle_db + rag.key_string_value_json_storage_cls.db = oracle_db + rag.chunks_vdb.db = oracle_db + rag.relationships_vdb.db = oracle_db + rag.entities_vdb.db = oracle_db + rag.graph_storage_cls.db = oracle_db + rag.chunk_entity_relation_graph.db = oracle_db + rag.llm_response_cache.db = oracle_db + rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func + # Extract and Insert into LightRAG storage with open("./dickens/demo.txt", "r", encoding="utf-8") as f: await rag.ainsert(f.read()) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 34745312..d464bcc4 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -3,7 +3,7 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Union +from typing import Union, List, Dict, Set, Any, Tuple import numpy as np import array @@ -12,6 +12,9 @@ from ..base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, + DocStatusStorage, + DocStatus, + DocProcessingStatus, ) import oracledb @@ -167,6 +170,9 @@ class OracleDB: @dataclass class OracleKVStorage(BaseKVStorage): # should pass db object to self.db + db: OracleDB = None + meta_fields = None + def __post_init__(self): self._data = {} self._max_batch_size = self.global_config["embedding_batch_num"] @@ -174,28 +180,56 @@ class OracleKVStorage(BaseKVStorage): ################ QUERY METHODS ################ async def get_by_id(self, id: str) -> Union[dict, None]: - """根据 id 获取 doc_full 数据.""" + """get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL, params) + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(SQL, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + else: + res = await self.db.query(SQL, params) if res: - data = res # {"data":res} - # print (data) - return data + return res + else: + return None + + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + """Specifically for llm_response_cache.""" + SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(SQL, params, multirows=True) + res = {} + for row in 
array_res: + res[row["id"]] = row + return res else: return None # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: - """根据 id 获取 doc_chunks 数据""" + """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} # print("get_by_ids:"+SQL) - # print(params) res = await self.db.query(SQL, params, multirows=True) + if "llm_response_cache" == self.namespace: + modes = set() + dict_res: dict[str, dict] = {} + for row in res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in res: + dict_res[row["mode"]][row["id"]] = row + res = [{k: v} for k, v in dict_res.items()] + if res: data = res # [{"data":i} for i in res] # print(data) @@ -204,7 +238,7 @@ class OracleKVStorage(BaseKVStorage): return None async def filter_keys(self, keys: list[str]) -> set[str]: - """过滤掉重复内容""" + """remove duplicated""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) ) @@ -271,13 +305,26 @@ class OracleKVStorage(BaseKVStorage): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] data = { - "check_id": k, "id": k, "content": v["content"], "workspace": self.db.workspace, } # print(merge_sql) await self.db.execute(merge_sql, data) + + if self.namespace == "llm_response_cache": + for mode, items in data.items(): + for k, v in items.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "cache_mode": mode, + } + + await self.db.execute(upsert_sql, _data) return left_data async def index_done_callback(self): @@ -285,8 +332,99 @@ class OracleKVStorage(BaseKVStorage): logger.info("full doc and chunk data had been saved into oracle db!") +@dataclass +class OracleDocStatusStorage(DocStatusStorage): + """Oracle implementation of document status storage""" + # should pass db object to self.db + db: OracleDB = None + meta_fields = None + + def __post_init__(self): + pass + + async def filter_keys(self, ids: list[str]) -> set[str]: + """Return keys that don't exist in storage""" + SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( + ids = ",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + res = await self.db.query(SQL, params, True) + # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. + if res: + existed = set([element["id"] for element in res]) + return set(ids) - existed + else: + return set(ids) + + async def get_status_counts(self) -> Dict[str, int]: + """Get counts of documents in each status""" + SQL = SQL_TEMPLATES["get_status_counts"] + params = {"workspace": self.db.workspace} + res = await self.db.query(SQL, params, True) + # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] 
+ counts = {} + for doc in res: + counts[doc["status"]] = doc["count"] + return counts + + async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]: + """Get all documents by status""" + SQL = SQL_TEMPLATES["get_docs_by_status"] + params = {"workspace": self.db.workspace, "status": status} + res = await self.db.query(SQL, params, True) + # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] + # Converting to be a dict + return { + element["id"]: DocProcessingStatus( + #content_summary=element["content_summary"], + content_summary = "", + content_length=element["CONTENT_LENGTH"], + status=element["STATUS"], + created_at=element["CREATETIME"], + updated_at=element["UPDATETIME"], + chunks_count=-1, + #chunks_count=element["chunks_count"], + ) + for element in res + } + + async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all failed documents""" + return await self.get_docs_by_status(DocStatus.FAILED) + + async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all pending documents""" + return await self.get_docs_by_status(DocStatus.PENDING) + + async def index_done_callback(self): + """Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here""" + logger.info("Doc status had been saved into ORACLE db!") + + async def upsert(self, data: dict[str, dict]): + """Update or insert document status + + Args: + data: Dictionary of document IDs and their status data + """ + SQL = SQL_TEMPLATES["merge_doc_status"] + for k, v in data.items(): + # chunks_count is optional + params = { + "workspace": self.db.workspace, + "id": k, + "content_summary": v["content_summary"], + "content_length": v["content_length"], + "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, + "status": v["status"], + } + await self.db.execute(SQL, params) + return data + + @dataclass class OracleVectorDBStorage(BaseVectorStorage): + # should pass db object to self.db + db: OracleDB = None cosine_better_than_threshold: float = 0.2 def __post_init__(self): @@ -564,13 +702,18 @@ N_T = { TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id varchar(256)PRIMARY KEY, + id varchar(256), workspace varchar(1024), doc_name varchar(1024), content CLOB, meta JSON, + content_summary varchar(1024), + content_length NUMBER, + status varchar(256), + chunks_count NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL + updatetime TIMESTAMP DEFAULT NULL, + error varchar(4096) )""" }, "LIGHTRAG_DOC_CHUNKS": { @@ -618,10 +761,16 @@ TABLES = { }, "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - id varchar(256) PRIMARY KEY, - send clob, - return clob, - model varchar(1024), + id varchar(256) PRIMARY KEY, + workspace varchar(1024), + cache_mode varchar(256), + model_name varchar(256), + original_prompt clob, + return_value clob, + embedding CLOB, + embedding_shape NUMBER, + embedding_min NUMBER, + embedding_max NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL )""" @@ -647,22 +796,70 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from 
LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + + "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", + + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", + + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:id,:content,:workspace) - """, + + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace) values(:id,:content,:workspace)""", + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + USING DUAL + ON (a.id = :check_id) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + + "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a + USING DUAL + ON (a.id = :id) + WHEN NOT MATCHED THEN + INSERT (workspace,id,original_prompt,return_value,cache_mode) + VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) + WHEN MATCHED THEN UPDATE + SET original_prompt = :original_prompt, + return_value = :return_value, + cache_mode = :cache_mode, + updatetime = SYSDATE""", + + "get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})", + + "get_status_counts": """SELECT status as "status", COUNT(1) as "count" + FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""", + + "get_docs_by_status": """select content_length,status, + TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime + from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""", + + "merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status) + WHEN MATCHED THEN UPDATE + SET content_summary = :content_summary, + content_length = :content_length, + chunks_count = :chunks_count, + status = :status, + updatetime = SYSDATE""", + # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) 
as distance @@ -714,16 +911,22 @@ SQL_TEMPLATES = { COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.name=:name) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cbe49da2..b6d5238e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -79,6 +79,7 @@ Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") +OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage") MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") @@ -290,6 +291,7 @@ class LightRAG: # kv storage "JsonKVStorage": JsonKVStorage, "OracleKVStorage": OracleKVStorage, + "OracleDocStatusStorage":OracleDocStatusStorage, "MongoKVStorage": MongoKVStorage, "TiDBKVStorage": TiDBKVStorage, # vector storage diff --git a/lightrag/operate.py b/lightrag/operate.py index b2c4d215..45ba9656 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -59,13 +59,15 @@ async def _handle_entity_relation_summary( description: str, global_config: dict, ) -> str: + """Handle entity relation summary + For each entity or relation, input is the combined description of already existing description and new description. + If too long, use LLM to summarize. 
+ """ use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] - language = global_config["addon_params"].get( - "language", PROMPTS["DEFAULT_LANGUAGE"] - ) + language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"]) tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) if len(tokens) < summary_max_tokens: # No need for summary @@ -139,6 +141,7 @@ async def _merge_nodes_then_upsert( knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): + """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] @@ -319,7 +322,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages) + history = json.dumps(history_messages,ensure_ascii=False) _prompt = history + "\n" + input_text else: _prompt = input_text @@ -351,6 +354,11 @@ async def extract_entities( return await use_llm_func(input_text) async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """"Prpocess a single chunk + Args: + chunk_key_dp (tuple[str, TextChunkSchema]): + ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + """ nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] diff --git a/lightrag/utils.py b/lightrag/utils.py index 1f6bf405..56e4191c 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -36,7 +36,7 @@ logger = logging.getLogger("lightrag") def set_logger(log_file: str): logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file) + file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -473,7 +473,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"] llm_model_func = hashing_kv.global_config.get("llm_model_func") current_embedding = await embedding_model_func([prompt]) From d5ae6669eafafa2c3ee890729965f34d09e03eca Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:52:37 +0800 Subject: [PATCH 3/4] support pipeline mode --- examples/lightrag_oracle_demo.py | 46 ++-- lightrag/kg/oracle_impl.py | 224 +++++------------- lightrag/lightrag.py | 392 +++++++++++++++++++++---------- lightrag/operate.py | 27 ++- lightrag/utils.py | 8 +- 5 files changed, 374 insertions(+), 323 deletions(-) diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 2cc8e018..8a5439e2 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -87,12 +87,14 @@ async def main(): # We use Oracle DB as the KV/vector/graph storage # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( - working_dir=WORKING_DIR, + # log_level="DEBUG", + working_dir=WORKING_DIR, entity_extract_max_gleaning = 1, - enable_llm_cache=False, - 
embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, + enable_llm_cache=True, enable_llm_cache_for_entity_extract = True, + embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, + chunk_token_size=CHUNK_TOKEN_SIZE, llm_model_max_token_size = MAX_TOKENS, @@ -106,34 +108,30 @@ async def main(): graph_storage = "OracleGraphStorage", kv_storage = "OracleKVStorage", vector_storage="OracleVectorDBStorage", - doc_status_storage="OracleDocStatusStorage", - addon_params = {"example_number":1, "language":"Simplfied Chinese"}, + addon_params = {"example_number":1, + "language":"Simplfied Chinese", + "entity_types": ["organization", "person", "geo", "event"], + "insert_batch_size":2, + } ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.key_string_value_json_storage_cls.db = oracle_db - rag.vector_db_storage_cls.db = oracle_db - rag.graph_storage_cls.db = oracle_db - rag.doc_status_storage_cls.db = oracle_db - rag.doc_status.db = oracle_db - rag.full_docs.db = oracle_db - rag.text_chunks.db = oracle_db - rag.llm_response_cache.db = oracle_db - rag.key_string_value_json_storage_cls.db = oracle_db - rag.chunks_vdb.db = oracle_db - rag.relationships_vdb.db = oracle_db - rag.entities_vdb.db = oracle_db - rag.graph_storage_cls.db = oracle_db - rag.chunk_entity_relation_graph.db = oracle_db - rag.llm_response_cache.db = oracle_db + rag.set_storage_client(db_client = oracle_db) - rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func - # Extract and Insert into LightRAG storage - with open("./dickens/demo.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) + with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f: + all_text = f.read() + texts = [x for x in all_text.split("\n") if x] + + # New mode use pipeline + await rag.apipeline_process_documents(texts) + await rag.apipeline_process_chunks() + await rag.apipeline_process_extract_graph() + # Old method use ainsert + #await rag.ainsert(texts) + # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] for mode in modes: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index d464bcc4..c9deed4e 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -12,9 +12,6 @@ from ..base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, - DocStatusStorage, - DocStatus, - DocProcessingStatus, ) import oracledb @@ -156,8 +153,6 @@ class OracleDB: if data is None: await cursor.execute(sql) else: - # print(data) - # print(sql) await cursor.execute(sql, data) await connection.commit() except Exception as e: @@ -175,7 +170,7 @@ class OracleKVStorage(BaseKVStorage): def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num",10) ################ QUERY METHODS ################ @@ -204,12 +199,11 @@ class OracleKVStorage(BaseKVStorage): array_res = await self.db.query(SQL, params, multirows=True) res = {} for row in array_res: - res[row["id"]] = row + res[row["id"]] = row return res else: return None - - # Query by id + async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( @@ -228,8 +222,7 @@ class OracleKVStorage(BaseKVStorage): dict_res[mode] = {} for row in res: dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for 
k, v in dict_res.items()] - + res = [{k: v} for k, v in dict_res.items()] if res: data = res # [{"data":i} for i in res] # print(data) @@ -237,38 +230,42 @@ class OracleKVStorage(BaseKVStorage): else: return None + async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]: + """Specifically for llm_response_cache.""" + if ids is not None: + SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + else: + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + res = await self.db.query(SQL, params, multirows=True) + if res: + return res + else: + return None + async def filter_keys(self, keys: list[str]) -> set[str]: - """remove duplicated""" + """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) ) params = {"workspace": self.db.workspace} - try: - await self.db.query(SQL, params) - except Exception as e: - logger.error(f"Oracle database error: {e}") - print(SQL) - print(params) res = await self.db.query(SQL, params, multirows=True) - data = None if res: exist_keys = [key["id"] for key in res] data = set([s for s in keys if s not in exist_keys]) + return data else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data + return set(keys) + ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - # print(self._data) - # values = [] if self.namespace == "text_chunks": list_data = [ { - "__id__": k, + "id": k, **{k1: v1 for k1, v1 in v.items()}, } for k, v in data.items() @@ -284,33 +281,30 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - # print(list_data) + + merge_sql = SQL_TEMPLATES["merge_chunk"] for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"] - data = { - "check_id": item["__id__"], - "id": item["__id__"], + _data = { + "id": item["id"], "content": item["content"], "workspace": self.db.workspace, "tokens": item["tokens"], "chunk_order_index": item["chunk_order_index"], "full_doc_id": item["full_doc_id"], "content_vector": item["__vector__"], + "status": item["status"], } - # print(merge_sql) - await self.db.execute(merge_sql, data) - + await self.db.execute(merge_sql, _data) if self.namespace == "full_docs": - for k, v in self._data.items(): + for k, v in data.items(): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] - data = { + _data = { "id": k, "content": v["content"], "workspace": self.db.workspace, } - # print(merge_sql) - await self.db.execute(merge_sql, data) + await self.db.execute(merge_sql, _data) if self.namespace == "llm_response_cache": for mode, items in data.items(): @@ -325,102 +319,20 @@ class OracleKVStorage(BaseKVStorage): } await self.db.execute(upsert_sql, _data) - return left_data + return None + + async def change_status(self, id: str, status: str): + SQL = SQL_TEMPLATES["change_status"].format( + table_name=N_T[self.namespace] + ) + params = {"workspace": self.db.workspace, "id": id, "status": status} + await self.db.execute(SQL, params) async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: logger.info("full doc and chunk data had been saved into oracle db!") 
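The KV methods above give OracleKVStorage a small status-tracking contract: rows are upserted with an initial status, pending or failed work is rediscovered with `get_by_status_and_ids`, and finished rows are flipped with `change_status`. Below is a minimal, self-contained sketch of that flow; the `InMemoryStatusKV` class, its storage dict, and the demo values are illustrative assumptions standing in for the Oracle-backed implementation, and only the three method names mirror the interface in this patch.

```python
import asyncio


class InMemoryStatusKV:
    """Illustrative stand-in for the status-tracking part of OracleKVStorage."""

    def __init__(self):
        self._rows: dict[str, dict] = {}

    async def upsert(self, data: dict[str, dict]) -> None:
        # Rows arrive with whatever status the caller sets (typically PENDING).
        self._rows.update(data)

    async def get_by_status_and_ids(self, status: str, ids=None):
        rows = [
            {"id": key, **value}
            for key, value in self._rows.items()
            if value.get("status") == status and (ids is None or key in ids)
        ]
        return rows or None

    async def change_status(self, id: str, status: str) -> None:
        self._rows[id]["status"] = status


async def demo():
    kv = InMemoryStatusKV()
    await kv.upsert({"doc-1": {"content": "hello", "status": "PENDING"}})
    # A pipeline pass picks up pending rows, processes them, then flips the status.
    for row in await kv.get_by_status_and_ids("PENDING") or []:
        await kv.change_status(row["id"], "PROCESSED")
    print(await kv.get_by_status_and_ids("PROCESSED"))


asyncio.run(demo())
```

This is the same state machine the new `apipeline_*` methods in `lightrag.py` drive later in this series, just without a database behind it.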
-@dataclass -class OracleDocStatusStorage(DocStatusStorage): - """Oracle implementation of document status storage""" - # should pass db object to self.db - db: OracleDB = None - meta_fields = None - - def __post_init__(self): - pass - - async def filter_keys(self, ids: list[str]) -> set[str]: - """Return keys that don't exist in storage""" - SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( - ids = ",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - res = await self.db.query(SQL, params, True) - # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. - if res: - existed = set([element["id"] for element in res]) - return set(ids) - existed - else: - return set(ids) - - async def get_status_counts(self) -> Dict[str, int]: - """Get counts of documents in each status""" - SQL = SQL_TEMPLATES["get_status_counts"] - params = {"workspace": self.db.workspace} - res = await self.db.query(SQL, params, True) - # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] - counts = {} - for doc in res: - counts[doc["status"]] = doc["count"] - return counts - - async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]: - """Get all documents by status""" - SQL = SQL_TEMPLATES["get_docs_by_status"] - params = {"workspace": self.db.workspace, "status": status} - res = await self.db.query(SQL, params, True) - # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] - # Converting to be a dict - return { - element["id"]: DocProcessingStatus( - #content_summary=element["content_summary"], - content_summary = "", - content_length=element["CONTENT_LENGTH"], - status=element["STATUS"], - created_at=element["CREATETIME"], - updated_at=element["UPDATETIME"], - chunks_count=-1, - #chunks_count=element["chunks_count"], - ) - for element in res - } - - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all failed documents""" - return await self.get_docs_by_status(DocStatus.FAILED) - - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: - """Get all pending documents""" - return await self.get_docs_by_status(DocStatus.PENDING) - - async def index_done_callback(self): - """Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here""" - logger.info("Doc status had been saved into ORACLE db!") - - async def upsert(self, data: dict[str, dict]): - """Update or insert document status - - Args: - data: Dictionary of document IDs and their status data - """ - SQL = SQL_TEMPLATES["merge_doc_status"] - for k, v in data.items(): - # chunks_count is optional - params = { - "workspace": self.db.workspace, - "id": k, - "content_summary": v["content_summary"], - "content_length": v["content_length"], - "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, - "status": v["status"], - } - await self.db.execute(SQL, params) - return data - - @dataclass class OracleVectorDBStorage(BaseVectorStorage): # should pass db object to self.db @@ -466,7 +378,7 @@ class OracleGraphStorage(BaseGraphStorage): def __post_init__(self): """从graphml文件加载图""" - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ @@ -500,7 +412,6 @@ class OracleGraphStorage(BaseGraphStorage): "content": 
content, "content_vector": content_vector, } - # print(merge_sql) await self.db.execute(merge_sql, data) # self._graph.add_node(node_id, **node_data) @@ -718,9 +629,10 @@ TABLES = { }, "LIGHTRAG_DOC_CHUNKS": { "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id varchar(256) PRIMARY KEY, + id varchar(256), workspace varchar(1024), full_doc_id varchar(256), + status varchar(256), chunk_order_index NUMBER, tokens NUMBER, content CLOB, @@ -795,9 +707,9 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", @@ -808,24 +720,34 @@ SQL_TEMPLATES = { "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", - "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + + "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", + "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", + + "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", + + "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", + "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = :id and a.workspace = :workspace) WHEN NOT MATCHED THEN INSERT(id,content,workspace) values(:id,:content,:workspace)""", - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + ON (id = :id and workspace = :workspace) + WHEN NOT MATCHED THEN INSERT + (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) + values 
(:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a USING DUAL @@ -838,27 +760,7 @@ SQL_TEMPLATES = { return_value = :return_value, cache_mode = :cache_mode, updatetime = SYSDATE""", - - "get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})", - "get_status_counts": """SELECT status as "status", COUNT(1) as "count" - FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""", - - "get_docs_by_status": """select content_length,status, - TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime - from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""", - - "merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :id and a.workspace = :workspace) - WHEN NOT MATCHED THEN - INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status) - WHEN MATCHED THEN UPDATE - SET content_summary = :content_summary, - content_length = :content_length, - chunks_count = :chunks_count, - status = :status, - updatetime = SYSDATE""", # SQL for VectorStorage "entities": """SELECT name as entity_name FROM diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6e2f1c0e..7d8cdf45 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -26,6 +26,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, + statistic_data ) from .base import ( BaseGraphStorage, @@ -36,22 +37,31 @@ from .base import ( DocStatus, ) -from .storage import ( - JsonKVStorage, - NanoVectorDBStorage, - NetworkXStorage, - JsonDocStatusStorage, -) - from .prompt import GRAPH_FIELD_SEP +STORAGES = { + "JsonKVStorage": '.storage', + "NanoVectorDBStorage": '.storage', + "NetworkXStorage": '.storage', + "JsonDocStatusStorage": '.storage', -# future KG integrations - -# from .kg.ArangoDB_impl import ( -# GraphStorage as ArangoDBStorage -# ) - + "Neo4JStorage":".kg.neo4j_impl", + "OracleKVStorage":".kg.oracle_impl", + "OracleGraphStorage":".kg.oracle_impl", + "OracleVectorDBStorage":".kg.oracle_impl", + "MilvusVectorDBStorge":".kg.milvus_impl", + "MongoKVStorage":".kg.mongo_impl", + "ChromaVectorDBStorage":".kg.chroma_impl", + "TiDBKVStorage":".kg.tidb_impl", + "TiDBVectorDBStorage":".kg.tidb_impl", + "TiDBGraphStorage":".kg.tidb_impl", + "PGKVStorage":".kg.postgres_impl", + "PGVectorStorage":".kg.postgres_impl", + "AGEStorage":".kg.age_impl", + "PGGraphStorage":".kg.postgres_impl", + "GremlinStorage":".kg.gremlin_impl", + "PGDocStatusStorage":".kg.postgres_impl", +} def lazy_external_import(module_name: str, class_name: str): """Lazily import a class from an external module based on the package of the caller.""" @@ -65,36 +75,13 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib - - # Import the module using importlib module = importlib.import_module(module_name, package=package) - - # Get the class from the module and instantiate it cls = getattr(module, class_name) return cls(*args, **kwargs) return import_class -Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") -OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") -OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") -OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") 
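The new `STORAGES` name-to-module table, combined with `lazy_external_import`, replaces the block of eager module-level bindings being removed here and below: a backend module is only imported when its class name is actually requested. A rough sketch of that pattern under stated assumptions follows; the registry entries and class names are placeholders drawn from the standard library rather than the real LightRAG storage backends, so the snippet runs on its own.

```python
import importlib

# Name -> (module, class). In LightRAG the module paths are package-relative
# strings such as ".kg.oracle_impl" resolved against the caller's package.
REGISTRY = {
    "OrderedDictStorage": ("collections", "OrderedDict"),
    "PathStorage": ("pathlib", "PurePosixPath"),
}


def lazy_import(name: str):
    """Return a factory that imports and instantiates the class only when called."""
    module_name, class_name = REGISTRY[name]

    def factory(*args, **kwargs):
        # The import happens here, on first use, not at module load time.
        module = importlib.import_module(module_name)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return factory


storage_cls = lazy_import("OrderedDictStorage")  # nothing imported yet
print(storage_cls(a=1))                          # class is imported and built here
```

The payoff is that optional dependencies (oracledb, neo4j, and so on) are never touched unless the corresponding storage name is actually configured.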
-OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage") -MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") -MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") -ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") -TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") -TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") -TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage") -PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage") -PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage") -AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") -PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage") -GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage") -PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage") - - def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. @@ -198,52 +185,64 @@ class LightRAG: logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") - - _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) - logger.debug(f"LightRAG init with param:\n {_print_config}\n") - - # @TODO: should move all storage setup here to leverage initial start params attached to self. - - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( - self._get_storage_class()[self.kv_storage] - ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ - self.vector_storage - ] - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ - self.graph_storage - ] - if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), - embedding_func=None, - ) + # show config + global_config=asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) + logger.debug(f"LightRAG init with param:\n {_print_config}\n") + # Init LLM self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) + + + # Initialize all storages + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage) + + self.key_string_value_json_storage_cls = partial( + self.key_string_value_json_storage_cls, + global_config=global_config + ) + + self.vector_db_storage_cls = partial( + self.vector_db_storage_cls, + global_config=global_config + ) + + self.graph_storage_cls = partial( + self.graph_storage_cls, + global_config=global_config + ) + + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace="json_doc_status_storage", + embedding_func=None, + ) + + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, + ) #### # add embedding func by walter #### self.full_docs = self.key_string_value_json_storage_cls( namespace="full_docs", - 
global_config=asdict(self), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( namespace="text_chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( namespace="chunk_entity_relation", - global_config=asdict(self), embedding_func=self.embedding_func, ) #### @@ -252,73 +251,64 @@ class LightRAG: self.entities_vdb = self.vector_db_storage_cls( namespace="entities", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( namespace="relationships", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( namespace="chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) + if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"): + hashing_kv = self.llm_response_cache + else: + hashing_kv = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, + ) + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), - embedding_func=None, - ), + hashing_kv=hashing_kv, **self.llm_model_kwargs, ) ) # Initialize document status storage - self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage] + self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status = self.doc_status_storage_cls( namespace="doc_status", - global_config=asdict(self), + global_config=global_config, embedding_func=None, ) - def _get_storage_class(self) -> dict: - return { - # kv storage - "JsonKVStorage": JsonKVStorage, - "OracleKVStorage": OracleKVStorage, - "OracleDocStatusStorage":OracleDocStatusStorage, - "MongoKVStorage": MongoKVStorage, - "TiDBKVStorage": TiDBKVStorage, - # vector storage - "NanoVectorDBStorage": NanoVectorDBStorage, - "OracleVectorDBStorage": OracleVectorDBStorage, - "MilvusVectorDBStorge": MilvusVectorDBStorge, - "ChromaVectorDBStorage": ChromaVectorDBStorage, - "TiDBVectorDBStorage": TiDBVectorDBStorage, - # graph storage - "NetworkXStorage": NetworkXStorage, - "Neo4JStorage": Neo4JStorage, - "OracleGraphStorage": OracleGraphStorage, - "AGEStorage": AGEStorage, - "PGGraphStorage": PGGraphStorage, - "PGKVStorage": PGKVStorage, - "PGDocStatusStorage": PGDocStatusStorage, - "PGVectorStorage": PGVectorStorage, - "TiDBGraphStorage": TiDBGraphStorage, - "GremlinStorage": GremlinStorage, - # "ArangoDBStorage": ArangoDBStorage - "JsonDocStatusStorage": JsonDocStatusStorage, - } + def _get_storage_class(self, storage_name: str) -> dict: + import_path = STORAGES[storage_name] + storage_class = lazy_external_import(import_path, storage_name) + return storage_class + + def set_storage_client(self,db_client): + # Now only tested on Oracle Database + for storage in [self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache]: + # set 
client + storage.db = db_client def insert( self, string_or_strings, split_by_character=None, split_by_character_only=False @@ -358,6 +348,11 @@ class LightRAG: } for content in unique_contents } + + # 3. Store original document and chunks + await self.full_docs.upsert( + {doc_id: {"content": doc["content"]}} + ) # 3. Filter out already processed documents _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) @@ -406,12 +401,7 @@ class LightRAG: } # Update status with chunks information - doc_status.update( - { - "chunks_count": len(chunks), - "updated_at": datetime.now().isoformat(), - } - ) + doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()}) await self.doc_status.upsert({doc_id: doc_status}) try: @@ -435,30 +425,16 @@ class LightRAG: self.chunk_entity_relation_graph = maybe_new_kg - # Store original document and chunks - await self.full_docs.upsert( - {doc_id: {"content": doc["content"]}} - ) + await self.text_chunks.upsert(chunks) # Update status to processed - doc_status.update( - { - "status": DocStatus.PROCESSED, - "updated_at": datetime.now().isoformat(), - } - ) + doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()}) await self.doc_status.upsert({doc_id: doc_status}) except Exception as e: # Mark as failed if any step fails - doc_status.update( - { - "status": DocStatus.FAILED, - "error": str(e), - "updated_at": datetime.now().isoformat(), - } - ) + doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()}) await self.doc_status.upsert({doc_id: doc_status}) raise e @@ -540,6 +516,174 @@ class LightRAG: if update_storage: await self._insert_done() + async def apipeline_process_documents(self, string_or_strings): + """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs + Args: + string_or_strings: Single document string or list of document strings + """ + if isinstance(string_or_strings, str): + string_or_strings = [string_or_strings] + + # 1. Remove duplicate contents from the list + unique_contents = list(set(doc.strip() for doc in string_or_strings)) + + logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents") + + # 2. Generate document IDs and initial status + new_docs = { + compute_mdhash_id(content, prefix="doc-"): { + "content": content, + "content_summary": self._get_content_summary(content), + "content_length": len(content), + "status": DocStatus.PENDING, + "created_at": datetime.now().isoformat(), + "updated_at": None, + } + for content in unique_contents + } + + # 3. Filter out already processed documents + _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) + if len(_not_stored_doc_keys) < len(new_docs): + logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents") + new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} + + if not new_docs: + logger.info(f"All documents have been processed or are duplicates") + return None + + # 4. Store original document + for doc_id, doc in new_docs.items(): + await self.full_docs.upsert({doc_id: {"content": doc["content"]}}) + await self.full_docs.change_status(doc_id, DocStatus.PENDING) + logger.info(f"Stored {len(new_docs)} new unique documents") + + async def apipeline_process_chunks(self): + """Get pendding documents, split into chunks,insert chunks""" + # 1. 
get all pending and failed documents + _todo_doc_keys = [] + _failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) + _pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + if _failed_doc: + _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) + if _pendding_doc: + _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc]) + if not _todo_doc_keys: + logger.info("All documents have been processed or are duplicates") + return None + else: + logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") + + new_docs = { + doc["id"]: doc + for doc in await self.full_docs.get_by_ids(_todo_doc_keys) + } + + # 2. split docs into chunks, insert chunks, update doc status + chunk_cnt = 0 + batch_size = self.addon_params.get("insert_batch_size", 10) + for i in range(0, len(new_docs), batch_size): + batch_docs = dict(list(new_docs.items())[i : i + batch_size]) + for doc_id, doc in tqdm_async( + batch_docs.items(), desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}" + ): + try: + # Generate chunks from document + chunks = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_id, + "status": DocStatus.PENDING, + } + for dp in chunking_by_token_size( + doc["content"], + overlap_token_size=self.chunk_overlap_token_size, + max_token_size=self.chunk_token_size, + tiktoken_model=self.tiktoken_model_name, + ) + } + chunk_cnt += len(chunks) + await self.text_chunks.upsert(chunks) + await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED) + + try: + # Store chunks in vector database + await self.chunks_vdb.upsert(chunks) + # Update doc status + await self.full_docs.change_status(doc_id, DocStatus.PROCESSED) + except Exception as e: + # Mark as failed if any step fails + await self.full_docs.change_status(doc_id, DocStatus.FAILED) + raise e + except Exception as e: + import traceback + error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + continue + logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") + + async def apipeline_process_extract_graph(self): + """Get pendding or failed chunks, extract entities and relationships from each chunk""" + # 1. 
get all pending and failed chunks + _todo_chunk_keys = [] + _failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + if _failed_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) + if _pendding_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks]) + if not _todo_chunk_keys: + logger.info("All chunks have been processed or are duplicates") + return None + + # Process documents in batches + batch_size = self.addon_params.get("insert_batch_size", 10) + + semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously + + async def process_chunk(chunk_id): + async with semaphore: + chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])} + # Extract and store entities and relationships + try: + maybe_new_kg = await extract_entities( + chunks, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + llm_response_cache=self.llm_response_cache, + global_config=asdict(self), + ) + if maybe_new_kg is None: + logger.info("No entities or relationships extracted!") + # Update status to processed + await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED) + except Exception as e: + logger.error("Failed to extract entities and relationships") + # Mark as failed if any step fails + await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) + raise e + + with tqdm_async(total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0) as progress: + tasks = [] + for chunk_id in _todo_chunk_keys: + task = asyncio.create_task(process_chunk(chunk_id)) + tasks.append(task) + + for future in asyncio.as_completed(tasks): + await future + progress.update(1) + progress.set_postfix({ + 'LLM call': statistic_data["llm_call"], + 'LLM cache': statistic_data["llm_cache"], + }) + + # Ensure all indexes are updated after each document + await self._insert_done() + async def _insert_done(self): tasks = [] for storage_inst in [ diff --git a/lightrag/operate.py b/lightrag/operate.py index 97ac245c..f9e48dbf 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,6 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, + statistic_data ) from .base import ( BaseGraphStorage, @@ -371,14 +372,16 @@ async def extract_entities( if need_to_restore: llm_response_cache.global_config = global_config if cached_return: + logger.debug(f"Found cache for {arg_hash}") + statistic_data["llm_cache"] += 1 return cached_return - + statistic_data["llm_call"] += 1 if history_messages: res: str = await use_llm_func( input_text, history_messages=history_messages ) else: - res: str = await use_llm_func(input_text) + res: str = await use_llm_func(input_text) await save_to_cache( llm_response_cache, CacheData(args_hash=arg_hash, content=res, prompt=_prompt), @@ -459,10 +462,8 @@ async def extract_entities( now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) ] - print( + logger.debug( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", - end="", - flush=True, ) return dict(maybe_nodes), dict(maybe_edges) @@ -470,8 +471,8 @@ async def extract_entities( for result in tqdm_async( asyncio.as_completed([_process_single_content(c) 
for c in ordered_chunks]), total=len(ordered_chunks), - desc="Extracting entities from chunks", - unit="chunk", + desc="Level 2 - Extracting entities and relationships", + unit="chunk", position=1,leave=False ): results.append(await result) @@ -482,7 +483,7 @@ async def extract_entities( maybe_nodes[k].extend(v) for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) - logger.info("Inserting entities into storage...") + logger.debug("Inserting entities into storage...") all_entities_data = [] for result in tqdm_async( asyncio.as_completed( @@ -492,12 +493,12 @@ async def extract_entities( ] ), total=len(maybe_nodes), - desc="Inserting entities", - unit="entity", + desc="Level 3 - Inserting entities", + unit="entity", position=2,leave=False ): all_entities_data.append(await result) - logger.info("Inserting relationships into storage...") + logger.debug("Inserting relationships into storage...") all_relationships_data = [] for result in tqdm_async( asyncio.as_completed( @@ -509,8 +510,8 @@ async def extract_entities( ] ), total=len(maybe_edges), - desc="Inserting relationships", - unit="relationship", + desc="Level 3 - Inserting relationships", + unit="relationship", position=3,leave=False ): all_relationships_data.append(await result) diff --git a/lightrag/utils.py b/lightrag/utils.py index 56e4191c..a83c0382 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -30,8 +30,13 @@ class UnlimitedSemaphore: ENCODER = None +statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} + logger = logging.getLogger("lightrag") +# Set httpx logging level to WARNING +logging.getLogger("httpx").setLevel(logging.WARNING) + def set_logger(log_file: str): logger.setLevel(logging.DEBUG) @@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return None, None, None, None # For naive mode, only use simple cache matching - if mode == "naive": + #if mode == "naive": + if mode == "default": if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} else: From 6ae8647285abcdfb6875f667c5c5daa1c7409fa8 Mon Sep 17 00:00:00 2001 From: jin <52519003+jin38324@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:58:15 +0800 Subject: [PATCH 4/4] support pipeline mode --- .gitignore | 2 +- examples/lightrag_oracle_demo.py | 44 +++--- lightrag/kg/oracle_impl.py | 52 +++---- lightrag/lightrag.py | 245 ++++++++++++++++++------------- lightrag/operate.py | 24 ++- lightrag/utils.py | 8 +- 6 files changed, 203 insertions(+), 172 deletions(-) diff --git a/.gitignore b/.gitignore index 0e0ec299..ec95f8a5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ -.DS_Store \ No newline at end of file +.DS_Store diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 8a5439e2..6de6e0a7 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -89,49 +89,45 @@ async def main(): rag = LightRAG( # log_level="DEBUG", working_dir=WORKING_DIR, - entity_extract_max_gleaning = 1, - + entity_extract_max_gleaning=1, enable_llm_cache=True, - enable_llm_cache_for_entity_extract = True, - embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, - - + enable_llm_cache_for_entity_extract=True, + embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90}, chunk_token_size=CHUNK_TOKEN_SIZE, - llm_model_max_token_size = MAX_TOKENS, + llm_model_max_token_size=MAX_TOKENS, 
llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=embedding_dimension, max_token_size=500, func=embedding_func, - ), - - graph_storage = "OracleGraphStorage", - kv_storage = "OracleKVStorage", + ), + graph_storage="OracleGraphStorage", + kv_storage="OracleKVStorage", vector_storage="OracleVectorDBStorage", - - addon_params = {"example_number":1, - "language":"Simplfied Chinese", - "entity_types": ["organization", "person", "geo", "event"], - "insert_batch_size":2, - } + addon_params={ + "example_number": 1, + "language": "Simplfied Chinese", + "entity_types": ["organization", "person", "geo", "event"], + "insert_batch_size": 2, + }, ) - # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.set_storage_client(db_client = oracle_db) + # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool + rag.set_storage_client(db_client=oracle_db) # Extract and Insert into LightRAG storage - with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f: + with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: all_text = f.read() texts = [x for x in all_text.split("\n") if x] - + # New mode use pipeline await rag.apipeline_process_documents(texts) - await rag.apipeline_process_chunks() + await rag.apipeline_process_chunks() await rag.apipeline_process_extract_graph() # Old method use ainsert - #await rag.ainsert(texts) - + # await rag.ainsert(texts) + # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] for mode in modes: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index c9deed4e..e30b6909 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -3,7 +3,7 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Union, List, Dict, Set, Any, Tuple +from typing import Union import numpy as np import array @@ -170,7 +170,7 @@ class OracleKVStorage(BaseKVStorage): def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config.get("embedding_batch_num",10) + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) ################ QUERY METHODS ################ @@ -190,7 +190,7 @@ class OracleKVStorage(BaseKVStorage): return res else: return None - + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] @@ -199,11 +199,11 @@ class OracleKVStorage(BaseKVStorage): array_res = await self.db.query(SQL, params, multirows=True) res = {} for row in array_res: - res[row["id"]] = row + res[row["id"]] = row return res else: return None - + async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( @@ -222,7 +222,7 @@ class OracleKVStorage(BaseKVStorage): dict_res[mode] = {} for row in res: dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] + res = [{k: v} for k, v in dict_res.items()] if res: data = res # [{"data":i} for i in res] # print(data) @@ -230,7 +230,9 @@ class OracleKVStorage(BaseKVStorage): else: return None - async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]: + async def get_by_status_and_ids( + self, status: str, ids: list[str] + ) -> Union[list[dict], None]: """Specifically for llm_response_cache.""" if ids is not 
None: SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( @@ -244,7 +246,7 @@ class OracleKVStorage(BaseKVStorage): return res else: return None - + async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( @@ -258,7 +260,6 @@ class OracleKVStorage(BaseKVStorage): return data else: return set(keys) - ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): @@ -281,7 +282,7 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - + merge_sql = SQL_TEMPLATES["merge_chunk"] for item in list_data: _data = { @@ -320,11 +321,9 @@ class OracleKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) return None - + async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format( - table_name=N_T[self.namespace] - ) + SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) params = {"workspace": self.db.workspace, "id": id, "status": status} await self.db.execute(SQL, params) @@ -673,8 +672,8 @@ TABLES = { }, "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - id varchar(256) PRIMARY KEY, - workspace varchar(1024), + id varchar(256) PRIMARY KEY, + workspace varchar(1024), cache_mode varchar(256), model_name varchar(256), original_prompt clob, @@ -708,47 +707,32 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", - "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", - "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", - "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", - "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", - "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", - "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", - "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", - "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = :id and a.workspace = :workspace) WHEN 
NOT MATCHED THEN INSERT(id,content,workspace) values(:id,:content,:workspace)""", - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS USING DUAL ON (id = :id and workspace = :workspace) WHEN NOT MATCHED THEN INSERT (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, - "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a USING DUAL ON (a.id = :id) @@ -760,8 +744,6 @@ SQL_TEMPLATES = { return_value = :return_value, cache_mode = :cache_mode, updatetime = SYSDATE""", - - # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance @@ -818,7 +800,7 @@ SQL_TEMPLATES = { INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) WHEN MATCHED THEN - UPDATE SET + UPDATE SET entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL @@ -827,7 +809,7 @@ SQL_TEMPLATES = { INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) WHEN MATCHED THEN - UPDATE SET + UPDATE SET weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7d8cdf45..0902fc50 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -26,7 +26,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, - statistic_data + statistic_data, ) from .base import ( BaseGraphStorage, @@ -39,30 +39,30 @@ from .base import ( from .prompt import GRAPH_FIELD_SEP -STORAGES = { - "JsonKVStorage": '.storage', - "NanoVectorDBStorage": '.storage', - "NetworkXStorage": '.storage', - "JsonDocStatusStorage": '.storage', - - "Neo4JStorage":".kg.neo4j_impl", - "OracleKVStorage":".kg.oracle_impl", - "OracleGraphStorage":".kg.oracle_impl", - "OracleVectorDBStorage":".kg.oracle_impl", - "MilvusVectorDBStorge":".kg.milvus_impl", - "MongoKVStorage":".kg.mongo_impl", - "ChromaVectorDBStorage":".kg.chroma_impl", - "TiDBKVStorage":".kg.tidb_impl", - "TiDBVectorDBStorage":".kg.tidb_impl", - "TiDBGraphStorage":".kg.tidb_impl", - "PGKVStorage":".kg.postgres_impl", - "PGVectorStorage":".kg.postgres_impl", - "AGEStorage":".kg.age_impl", - "PGGraphStorage":".kg.postgres_impl", - "GremlinStorage":".kg.gremlin_impl", - "PGDocStatusStorage":".kg.postgres_impl", +STORAGES = { + "JsonKVStorage": ".storage", + "NanoVectorDBStorage": ".storage", + "NetworkXStorage": ".storage", + "JsonDocStatusStorage": ".storage", + "Neo4JStorage": ".kg.neo4j_impl", + "OracleKVStorage": ".kg.oracle_impl", + "OracleGraphStorage": ".kg.oracle_impl", + "OracleVectorDBStorage": ".kg.oracle_impl", + "MilvusVectorDBStorge": ".kg.milvus_impl", + "MongoKVStorage": ".kg.mongo_impl", + "ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + 
"TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", } + def lazy_external_import(module_name: str, class_name: str): """Lazily import a class from an external module based on the package of the caller.""" @@ -75,6 +75,7 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib + module = importlib.import_module(module_name, package=package) cls = getattr(module, class_name) return cls(*args, **kwargs) @@ -190,7 +191,7 @@ class LightRAG: os.makedirs(self.working_dir) # show config - global_config=asdict(self) + global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -198,31 +199,33 @@ class LightRAG: self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) - # Initialize all storages - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage) - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage) - + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class(self.kv_storage) + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + self.vector_storage + ) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + self.graph_storage + ) + self.key_string_value_json_storage_cls = partial( - self.key_string_value_json_storage_cls, - global_config=global_config + self.key_string_value_json_storage_cls, global_config=global_config ) self.vector_db_storage_cls = partial( - self.vector_db_storage_cls, - global_config=global_config + self.vector_db_storage_cls, global_config=global_config ) self.graph_storage_cls = partial( - self.graph_storage_cls, - global_config=global_config + self.graph_storage_cls, global_config=global_config ) self.json_doc_status_storage = self.key_string_value_json_storage_cls( namespace="json_doc_status_storage", - embedding_func=None, + embedding_func=None, ) self.llm_response_cache = self.key_string_value_json_storage_cls( @@ -264,13 +267,15 @@ class LightRAG: embedding_func=self.embedding_func, ) - if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"): + if self.llm_response_cache and hasattr( + self.llm_response_cache, "global_config" + ): hashing_kv = self.llm_response_cache else: hashing_kv = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - embedding_func=None, - ) + namespace="llm_response_cache", + embedding_func=None, + ) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -292,21 +297,24 @@ class LightRAG: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name) return storage_class - - def set_storage_client(self,db_client): + + def set_storage_client(self, db_client): # Now only tested on Oracle Database - for storage in [self.vector_db_storage_cls, - self.graph_storage_cls, - self.doc_status, self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.key_string_value_json_storage_cls, - self.chunks_vdb, - self.relationships_vdb, - 
self.entities_vdb, - self.graph_storage_cls, - self.chunk_entity_relation_graph, - self.llm_response_cache]: + for storage in [ + self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache, + ]: # set client storage.db = db_client @@ -348,11 +356,6 @@ class LightRAG: } for content in unique_contents } - - # 3. Store original document and chunks - await self.full_docs.upsert( - {doc_id: {"content": doc["content"]}} - ) # 3. Filter out already processed documents _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) @@ -401,7 +404,12 @@ class LightRAG: } # Update status with chunks information - doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "chunks_count": len(chunks), + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) try: @@ -425,16 +433,30 @@ class LightRAG: self.chunk_entity_relation_graph = maybe_new_kg - + # Store original document and chunks + await self.full_docs.upsert( + {doc_id: {"content": doc["content"]}} + ) await self.text_chunks.upsert(chunks) # Update status to processed - doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "status": DocStatus.PROCESSED, + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) except Exception as e: # Mark as failed if any step fails - doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "status": DocStatus.FAILED, + "error": str(e), + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) raise e @@ -527,7 +549,9 @@ class LightRAG: # 1. Remove duplicate contents from the list unique_contents = list(set(doc.strip() for doc in string_or_strings)) - logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents") + logger.info( + f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents" + ) # 2. Generate document IDs and initial status new_docs = { @@ -542,28 +566,34 @@ class LightRAG: for content in unique_contents } - # 3. Filter out already processed documents + # 3. Filter out already processed documents _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) if len(_not_stored_doc_keys) < len(new_docs): - logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents") + logger.info( + f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents" + ) new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} if not new_docs: - logger.info(f"All documents have been processed or are duplicates") + logger.info("All documents have been processed or are duplicates") return None - # 4. Store original document + # 4. 
Store original document for doc_id, doc in new_docs.items(): await self.full_docs.upsert({doc_id: {"content": doc["content"]}}) await self.full_docs.change_status(doc_id, DocStatus.PENDING) logger.info(f"Stored {len(new_docs)} new unique documents") - + async def apipeline_process_chunks(self): - """Get pendding documents, split into chunks,insert chunks""" - # 1. get all pending and failed documents + """Get pendding documents, split into chunks,insert chunks""" + # 1. get all pending and failed documents _todo_doc_keys = [] - _failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) - _pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + _failed_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) if _failed_doc: _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) if _pendding_doc: @@ -573,10 +603,9 @@ class LightRAG: return None else: logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") - + new_docs = { - doc["id"]: doc - for doc in await self.full_docs.get_by_ids(_todo_doc_keys) + doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys) } # 2. split docs into chunks, insert chunks, update doc status @@ -585,8 +614,9 @@ class LightRAG: for i in range(0, len(new_docs), batch_size): batch_docs = dict(list(new_docs.items())[i : i + batch_size]) for doc_id, doc in tqdm_async( - batch_docs.items(), desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}" - ): + batch_docs.items(), + desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}", + ): try: # Generate chunks from document chunks = { @@ -616,18 +646,23 @@ class LightRAG: await self.full_docs.change_status(doc_id, DocStatus.FAILED) raise e except Exception as e: - import traceback - error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" - logger.error(error_msg) - continue - logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") + import traceback + + error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + continue + logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") async def apipeline_process_extract_graph(self): """Get pendding or failed chunks, extract entities and relationships from each chunk""" - # 1. get all pending and failed chunks + # 1. 
get all pending and failed chunks _todo_chunk_keys = [] - _failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) - _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + _failed_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) if _failed_chunks: _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) if _pendding_chunks: @@ -635,15 +670,19 @@ class LightRAG: if not _todo_chunk_keys: logger.info("All chunks have been processed or are duplicates") return None - + # Process documents in batches batch_size = self.addon_params.get("insert_batch_size", 10) - semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously + semaphore = asyncio.Semaphore( + batch_size + ) # Control the number of tasks that are processed simultaneously - async def process_chunk(chunk_id): + async def process_chunk(chunk_id): async with semaphore: - chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])} + chunks = { + i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id]) + } # Extract and store entities and relationships try: maybe_new_kg = await extract_entities( @@ -662,25 +701,29 @@ class LightRAG: logger.error("Failed to extract entities and relationships") # Mark as failed if any step fails await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) - raise e + raise e - with tqdm_async(total=len(_todo_chunk_keys), - desc="\nLevel 1 - Processing chunks", - unit="chunk", - position=0) as progress: + with tqdm_async( + total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0, + ) as progress: tasks = [] for chunk_id in _todo_chunk_keys: task = asyncio.create_task(process_chunk(chunk_id)) tasks.append(task) - + for future in asyncio.as_completed(tasks): await future progress.update(1) - progress.set_postfix({ - 'LLM call': statistic_data["llm_call"], - 'LLM cache': statistic_data["llm_cache"], - }) - + progress.set_postfix( + { + "LLM call": statistic_data["llm_call"], + "LLM cache": statistic_data["llm_cache"], + } + ) + # Ensure all indexes are updated after each document await self._insert_done() diff --git a/lightrag/operate.py b/lightrag/operate.py index f9e48dbf..fbfd1fbd 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,7 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, - statistic_data + statistic_data, ) from .base import ( BaseGraphStorage, @@ -105,7 +105,9 @@ async def _handle_entity_relation_summary( llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] - language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) if len(tokens) < summary_max_tokens: # No need for summary @@ -360,7 +362,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages,ensure_ascii=False) + history = json.dumps(history_messages, ensure_ascii=False) _prompt = history + "\n" + input_text else: 
_prompt = input_text @@ -381,7 +383,7 @@ async def extract_entities( input_text, history_messages=history_messages ) else: - res: str = await use_llm_func(input_text) + res: str = await use_llm_func(input_text) await save_to_cache( llm_response_cache, CacheData(args_hash=arg_hash, content=res, prompt=_prompt), @@ -394,7 +396,7 @@ async def extract_entities( return await use_llm_func(input_text) async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): - """"Prpocess a single chunk + """ "Prpocess a single chunk Args: chunk_key_dp (tuple[str, TextChunkSchema]): ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) @@ -472,7 +474,9 @@ async def extract_entities( asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), total=len(ordered_chunks), desc="Level 2 - Extracting entities and relationships", - unit="chunk", position=1,leave=False + unit="chunk", + position=1, + leave=False, ): results.append(await result) @@ -494,7 +498,9 @@ async def extract_entities( ), total=len(maybe_nodes), desc="Level 3 - Inserting entities", - unit="entity", position=2,leave=False + unit="entity", + position=2, + leave=False, ): all_entities_data.append(await result) @@ -511,7 +517,9 @@ async def extract_entities( ), total=len(maybe_edges), desc="Level 3 - Inserting relationships", - unit="relationship", position=3,leave=False + unit="relationship", + position=3, + leave=False, ): all_relationships_data.append(await result) diff --git a/lightrag/utils.py b/lightrag/utils.py index a83c0382..ce556ab2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -41,7 +41,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING) def set_logger(log_file: str): logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -458,7 +458,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return None, None, None, None # For naive mode, only use simple cache matching - #if mode == "naive": + # if mode == "naive": if mode == "default": if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} @@ -479,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"] + embedding_model_func = hashing_kv.global_config[ + "embedding_func" + ].func # ["func"] llm_model_func = hashing_kv.global_config.get("llm_model_func") current_embedding = await embedding_model_func([prompt])
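
The hunks above wire the new pipeline mode together: apipeline_process_documents deduplicates incoming texts and stores them with a pending status, apipeline_process_chunks splits pending or failed documents into token-bounded chunks in batches of addon_params["insert_batch_size"], and apipeline_process_extract_graph walks the pending or failed chunks under an asyncio.Semaphore while extract_entities counts LLM calls versus cache hits in statistic_data. The sketch below shows how a caller would drive that flow end to end. It is a minimal illustration based on the examples/lightrag_oracle_demo.py changes in this patch, not part of the diff itself: the run_pipeline wrapper, its parameters, and the literal values (working_dir, chunk sizes, batch size) are assumptions chosen to mirror the demo, and the llm_model_func and embedding_func callables are expected to be defined by the caller exactly as that demo does.

from lightrag import LightRAG, QueryParam
from lightrag.kg.oracle_impl import OracleDB
from lightrag.utils import EmbeddingFunc


async def run_pipeline(
    oracle_db: OracleDB,
    texts: list[str],
    llm_model_func,      # async LLM callable, defined as in the demo
    embedding_func,      # async embedding callable, defined as in the demo
    embedding_dim: int,  # must match the embedding model's output size
):
    # Configure LightRAG with the Oracle-backed storages added in this series.
    rag = LightRAG(
        working_dir="./dickens",
        chunk_token_size=1024,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dim,
            max_token_size=500,
            func=embedding_func,
        ),
        graph_storage="OracleGraphStorage",
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
        addon_params={"insert_batch_size": 2},
    )

    # Share one connection pool across every storage instance.
    rag.set_storage_client(db_client=oracle_db)

    # Pipeline mode: each stage only picks up PENDING/FAILED work,
    # so the whole sequence is safe to re-run after a partial failure.
    await rag.apipeline_process_documents(texts)   # dedupe and store full docs
    await rag.apipeline_process_chunks()           # split stored docs into chunks
    await rag.apipeline_process_extract_graph()    # extract entities/relations per chunk

    # Querying is unchanged; the pipeline only affects ingestion.
    print(await rag.aquery("What are the top themes?", param=QueryParam(mode="hybrid")))


# Typical entry point, with oracle_db built as in the demo's main():
# asyncio.run(run_pipeline(oracle_db, texts, llm_model_func, embedding_func, 1024))

Because every stage re-reads its work items by status through get_by_status_and_ids and records progress with change_status, an interrupted run can simply be restarted: items marked FAILED are retried, items already PROCESSED are skipped, which is the behaviour the status-handling additions in oracle_impl.py exist to support.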