From 85331e3fa20d755a9bd978f5e7a2604b0bea8d3b Mon Sep 17 00:00:00 2001
From: jin <52519003+jin38324@users.noreply.github.com>
Date: Fri, 10 Jan 2025 11:36:28 +0800
Subject: [PATCH] Update Oracle support: add cache support, fix bugs

---
 examples/lightrag_oracle_demo.py |  45 ++++--
 lightrag/kg/oracle_impl.py       | 265 +++++++++++++++++++++++++++----
 lightrag/lightrag.py             |   2 +
 lightrag/operate.py              |  16 +-
 lightrag/utils.py                |   4 +-
 5 files changed, 284 insertions(+), 48 deletions(-)

diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index bbb69319..2cc8e018 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
 APIKEY = "ocigenerativeai"
 CHATMODEL = "cohere.command-r-plus"
 EMBEDMODEL = "cohere.embed-multilingual-v3.0"
-
+CHUNK_TOKEN_SIZE = 1024
+MAX_TOKENS = 4000
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -86,27 +87,49 @@ async def main():
         # We use Oracle DB as the KV/vector/graph storage
         # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
         rag = LightRAG(
+            working_dir=WORKING_DIR,
+            entity_extract_max_gleaning=1,
+            enable_llm_cache=False,
-            working_dir=WORKING_DIR,
-            chunk_token_size=512,
+            embedding_cache_config=None,  # {"enabled": True, "similarity_threshold": 0.90},
+            enable_llm_cache_for_entity_extract=True,
+            chunk_token_size=CHUNK_TOKEN_SIZE,
+            llm_model_max_token_size=MAX_TOKENS,
             llm_model_func=llm_model_func,
             embedding_func=EmbeddingFunc(
                 embedding_dim=embedding_dimension,
-                max_token_size=512,
+                max_token_size=500,
                 func=embedding_func,
-            ),
-            graph_storage="OracleGraphStorage",
-            kv_storage="OracleKVStorage",
+            ),
+            graph_storage="OracleGraphStorage",
+            kv_storage="OracleKVStorage",
             vector_storage="OracleVectorDBStorage",
+            doc_status_storage="OracleDocStatusStorage",
+            addon_params={"example_number": 1, "language": "Simplified Chinese"},
         )
 
-        # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
-        rag.graph_storage_cls.db = oracle_db
+        # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
         rag.key_string_value_json_storage_cls.db = oracle_db
         rag.vector_db_storage_cls.db = oracle_db
-        # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
-        rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
+        rag.graph_storage_cls.db = oracle_db
+        rag.doc_status_storage_cls.db = oracle_db
+        rag.doc_status.db = oracle_db
+        rag.full_docs.db = oracle_db
+        rag.text_chunks.db = oracle_db
+        rag.llm_response_cache.db = oracle_db
+        rag.chunks_vdb.db = oracle_db
+        rag.relationships_vdb.db = oracle_db
+        rag.entities_vdb.db = oracle_db
+        rag.chunk_entity_relation_graph.db = oracle_db
+        rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
+
         # Extract and Insert into LightRAG storage
         with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
             await rag.ainsert(f.read())
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 34745312..d464bcc4 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -3,7 +3,7 @@ import asyncio
 # import html
 # import os
 from dataclasses import dataclass
-from typing import Union
+from typing import Union, List, Dict, Set, Any, Tuple
 
 import numpy as np
 import array
 
@@ -12,6 +12,9 @@ from ..base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
 )
 
 import oracledb
@@ -167,6 +170,9 @@ class OracleDB:
 @dataclass
 class OracleKVStorage(BaseKVStorage):
     # should pass db object to self.db
+    db: OracleDB = None
+    meta_fields = None
+
     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -174,28 +180,56 @@ class OracleKVStorage(BaseKVStorage):
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> Union[dict, None]:
-        """根据 id 获取 doc_full 数据."""
+        """Get doc_full data by id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL, params)
+        if self.namespace == "llm_response_cache":
+            array_res = await self.db.query(SQL, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+        else:
+            res = await self.db.query(SQL, params)
         if res:
-            data = res  # {"data":res}
-            # print (data)
-            return data
+            return res
+        else:
+            return None
+
+    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+        """Specifically for llm_response_cache."""
+        SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
+        if self.namespace == "llm_response_cache":
+            array_res = await self.db.query(SQL, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+            return res
         else:
             return None
 
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
-        """根据 id 获取 doc_chunks 数据"""
+        """Get doc_chunks data by id."""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.db.workspace}
         # print("get_by_ids:"+SQL)
-        # print(params)
         res = await self.db.query(SQL, params, multirows=True)
+        if self.namespace == "llm_response_cache":
+            modes = set()
+            dict_res: dict[str, dict] = {}
+            for row in res:
+                modes.add(row["mode"])
+            for mode in modes:
+                if mode not in dict_res:
+                    dict_res[mode] = {}
+            for row in res:
+                dict_res[row["mode"]][row["id"]] = row
+            res = [{k: v} for k, v in dict_res.items()]
+
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -204,7 +238,7 @@ class OracleKVStorage(BaseKVStorage):
             return None
 
     async def filter_keys(self, keys: list[str]) -> set[str]:
-        """过滤掉重复内容"""
+        """Filter out duplicate (already stored) keys."""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
         )
@@ -271,13 +305,26 @@ class OracleKVStorage(BaseKVStorage):
                 # values.clear()
             merge_sql = SQL_TEMPLATES["merge_doc_full"]
             data = {
-                "check_id": k,
                 "id": k,
                 "content": v["content"],
                 "workspace": self.db.workspace,
             }
             # print(merge_sql)
             await self.db.execute(merge_sql, data)
+
+        if self.namespace == "llm_response_cache":
+            for mode, items in data.items():
+                for k, v in items.items():
+                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
+                    _data = {
+                        "workspace": self.db.workspace,
+                        "id": k,
+                        "original_prompt": v["original_prompt"],
+                        "return_value": v["return"],
+                        "cache_mode": mode,
+                    }
+
+                    await self.db.execute(upsert_sql, _data)
         return left_data
 
     async def index_done_callback(self):
@@ -285,8 +332,99 @@ class OracleKVStorage(BaseKVStorage):
         logger.info("full doc and chunk data had been saved into oracle db!")
 
 
+@dataclass
+class OracleDocStatusStorage(DocStatusStorage):
+    """Oracle implementation of document status storage."""
+
+    # should pass db object to self.db
+    db: OracleDB = None
+    meta_fields = None
+
+    def __post_init__(self):
+        pass
+
+    async def filter_keys(self, ids: list[str]) -> set[str]:
+        """Return keys that don't exist in storage"""
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+            ids=",".join([f"'{id}'" for id in ids])
+        )
+        params = {"workspace": self.db.workspace}
+        res = await self.db.query(SQL, params, True)
+        # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
+        if res:
+            existed = set([element["id"] for element in res])
+            return set(ids) - existed
+        else:
+            return set(ids)
+
+    async def get_status_counts(self) -> Dict[str, int]:
+        """Get counts of documents in each status"""
+        SQL = SQL_TEMPLATES["get_status_counts"]
+        params = {"workspace": self.db.workspace}
+        res = await self.db.query(SQL, params, True)
+        # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
+        counts = {}
+        for doc in res:
+            counts[doc["status"]] = doc["count"]
+        return counts
+
+    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
+        """Get all documents with the given status"""
+        SQL = SQL_TEMPLATES["get_docs_by_status"]
+        params = {"workspace": self.db.workspace, "status": status}
+        res = await self.db.query(SQL, params, True)
+        # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
+        # Convert the rows into a dict of DocProcessingStatus
+        return {
+            element["id"]: DocProcessingStatus(
+                # content_summary=element["content_summary"],
+                content_summary="",
+                content_length=element["CONTENT_LENGTH"],
+                status=element["STATUS"],
+                created_at=element["CREATETIME"],
+                updated_at=element["UPDATETIME"],
+                chunks_count=-1,
+                # chunks_count=element["chunks_count"],
+            )
+            for element in res
+        }
+
+    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all failed documents"""
+        return await self.get_docs_by_status(DocStatus.FAILED)
+
+    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+        """Get all pending documents"""
+        return await self.get_docs_by_status(DocStatus.PENDING)
+
+    async def index_done_callback(self):
+        """Save data after indexing; for Oracle the status rows are already persisted during upsert, so there is nothing to do here"""
+        logger.info("Doc status has been saved into Oracle db!")
+
+    async def upsert(self, data: dict[str, dict]):
+        """Update or insert document status
+
+        Args:
+            data: Dictionary of document IDs and their status data
+        """
+        SQL = SQL_TEMPLATES["merge_doc_status"]
+        for k, v in data.items():
+            # chunks_count is optional
+            params = {
+                "workspace": self.db.workspace,
+                "id": k,
+                "content_summary": v["content_summary"],
+                "content_length": v["content_length"],
+                "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
+                "status": v["status"],
+            }
+            await self.db.execute(SQL, params)
+        return data
+
+
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
+    # should pass db object to self.db
+    db: OracleDB = None
     cosine_better_than_threshold: float = 0.2
 
     def __post_init__(self):
@@ -564,13 +702,18 @@ N_T = {
 TABLES = {
     "LIGHTRAG_DOC_FULL": {
         "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
-                    id varchar(256)PRIMARY KEY,
+                    id varchar(256),
                     workspace varchar(1024),
                     doc_name varchar(1024),
                     content CLOB,
                     meta JSON,
+                    content_summary varchar(1024),
+                    content_length NUMBER,
+                    status
varchar(256), + chunks_count NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL + updatetime TIMESTAMP DEFAULT NULL, + error varchar(4096) )""" }, "LIGHTRAG_DOC_CHUNKS": { @@ -618,10 +761,16 @@ TABLES = { }, "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - id varchar(256) PRIMARY KEY, - send clob, - return clob, - model varchar(1024), + id varchar(256) PRIMARY KEY, + workspace varchar(1024), + cache_mode varchar(256), + model_name varchar(256), + original_prompt clob, + return_value clob, + embedding CLOB, + embedding_shape NUMBER, + embedding_min NUMBER, + embedding_max NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL )""" @@ -647,22 +796,70 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + + "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", + + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", + + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", + "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:id,:content,:workspace) - """, + + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace) values(:id,:content,:workspace)""", + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + USING DUAL + ON (a.id = :check_id) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + + "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a + USING DUAL + ON (a.id = :id) + WHEN NOT MATCHED THEN + INSERT (workspace,id,original_prompt,return_value,cache_mode) + VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) + WHEN MATCHED THEN UPDATE + SET original_prompt = :original_prompt, + return_value = :return_value, + cache_mode = :cache_mode, + updatetime = SYSDATE""", + + "get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN 
({ids})", + + "get_status_counts": """SELECT status as "status", COUNT(1) as "count" + FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""", + + "get_docs_by_status": """select content_length,status, + TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime + from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""", + + "merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status) + WHEN MATCHED THEN UPDATE + SET content_summary = :content_summary, + content_length = :content_length, + chunks_count = :chunks_count, + status = :status, + updatetime = SYSDATE""", + # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance @@ -714,16 +911,22 @@ SQL_TEMPLATES = { COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.name=:name) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cbe49da2..b6d5238e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -79,6 +79,7 @@ Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") +OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage") MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") MongoKVStorage = 
lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
 ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
@@ -290,6 +291,7 @@ class LightRAG:
             # kv storage
             "JsonKVStorage": JsonKVStorage,
             "OracleKVStorage": OracleKVStorage,
+            "OracleDocStatusStorage": OracleDocStatusStorage,
             "MongoKVStorage": MongoKVStorage,
             "TiDBKVStorage": TiDBKVStorage,
             # vector storage
diff --git a/lightrag/operate.py b/lightrag/operate.py
index b2c4d215..45ba9656 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -59,13 +59,15 @@ async def _handle_entity_relation_summary(
     description: str,
     global_config: dict,
 ) -> str:
+    """Summarize an entity or relation description.
+    The input is the already existing description combined with the newly extracted one;
+    if it is too long, the LLM is used to summarize it.
+    """
     use_llm_func: callable = global_config["llm_model_func"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     tiktoken_model_name = global_config["tiktoken_model_name"]
     summary_max_tokens = global_config["entity_summary_to_max_tokens"]
-    language = global_config["addon_params"].get(
-        "language", PROMPTS["DEFAULT_LANGUAGE"]
-    )
+    language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
 
     tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
     if len(tokens) < summary_max_tokens:  # No need for summary
@@ -139,6 +141,7 @@ async def _merge_nodes_then_upsert(
     knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
 ):
+    """Look up the node by name in the knowledge graph; merge with the existing data if found, otherwise create it, then upsert."""
     already_entity_types = []
     already_source_ids = []
     already_description = []
@@ -319,7 +322,7 @@ async def extract_entities(
             llm_response_cache.global_config = new_config
             need_to_restore = True
         if history_messages:
-            history = json.dumps(history_messages)
+            history = json.dumps(history_messages, ensure_ascii=False)
             _prompt = history + "\n" + input_text
         else:
             _prompt = input_text
@@ -351,6 +354,11 @@ async def extract_entities(
         return await use_llm_func(input_text)
 
     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
+        """Process a single chunk.
+        Args:
+            chunk_key_dp (tuple[str, TextChunkSchema]):
+                ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
+        """
         nonlocal already_processed, already_entities, already_relations
         chunk_key = chunk_key_dp[0]
         chunk_dp = chunk_key_dp[1]
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 1f6bf405..56e4191c 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -36,7 +36,7 @@ logger = logging.getLogger("lightrag")
 def set_logger(log_file: str):
     logger.setLevel(logging.DEBUG)
 
-    file_handler = logging.FileHandler(log_file)
+    file_handler = logging.FileHandler(log_file, encoding="utf-8")
     file_handler.setLevel(logging.DEBUG)
 
     formatter = logging.Formatter(
@@ -473,7 +473,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
     quantized = min_val = max_val = None
     if is_embedding_cache_enabled:
         # Use embedding cache
-        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        embedding_model_func = hashing_kv.global_config["embedding_func"].func  # was: ["func"]
         llm_model_func = hashing_kv.global_config.get("llm_model_func")
         current_embedding = await embedding_model_func([prompt])
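
A possible follow-up, not included in this patch: the demo assigns `oracle_db` to each storage attribute one line at a time. A small helper along the lines of the sketch below could keep that wiring in one place; `bind_oracle_db` is a hypothetical name, and the attribute list simply mirrors the objects the demo already touches.

    def bind_oracle_db(rag, oracle_db):
        """Attach one shared OracleDB connection pool to every Oracle-backed storage."""
        storages = [
            rag.full_docs,
            rag.text_chunks,
            rag.llm_response_cache,
            rag.doc_status,
            rag.chunks_vdb,
            rag.relationships_vdb,
            rag.entities_vdb,
            rag.chunk_entity_relation_graph,
        ]
        for storage in storages:
            storage.db = oracle_db

    # Usage in main(), replacing the block of repeated assignments:
    # bind_oracle_db(rag, oracle_db)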