diff --git a/README.md b/README.md index 54a84323..ed2a7789 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News +- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage). - [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete). - [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise. - [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author. @@ -356,6 +357,11 @@ rag = LightRAG( ``` see test_neo4j.py for a working example. +### Using PostgreSQL for Storage +For production-level scenarios, you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution, serving as the KV store, the vector database (pgvector), and the graph database (Apache AGE). +* PostgreSQL is lightweight; the whole binary distribution, including all necessary plugins, can be zipped to 40 MB. Refer to the [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0); on Linux/Mac it is easy to install. +* How to start? Refer to [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py); a configuration sketch also appears after the init parameters table below. + ### Insert Custom KG ```python @@ -602,33 +608,34 @@ if __name__ == "__main__": ### LightRAG init parameters -| **Parameter** | **Type** | **Explanation** | **Default** | -| --- | --- | --- | --- | -| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | -| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` | -| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` | -| **graph\_storage** | `str` | Storage type for graph edges and nodes. 
Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` | -| **log\_level** | | Log level for application runtime | `logging.DEBUG` | -| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | -| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | -| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` | -| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` | -| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` | -| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` | -| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` | -| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` | -| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` | -| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` | -| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` | -| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` | -| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` | -| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` | -| **llm\_model\_kwargs** | `dict` | Additional parameters for LLM generation | | -| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | | -| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | -| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` | -| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | -| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | +| **Parameter** | **Type** | **Explanation** | **Default** | +|----------------------------------------------| --- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------| +| **working\_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | +| **kv\_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` | +| **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` | +| **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` | +| **log\_level** | | Log level for application runtime | `logging.DEBUG` | +| **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | +| **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | +| **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` | +| **entity\_extract\_max\_gleaning** | `int` | Number of loops in the entity extraction process, appending history messages | `1` | +| **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` | +| **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` | +| **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` | +| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` | +| **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` | +| **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` | +| **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` | +| **llm\_model\_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` | +| **llm\_model\_max\_token\_size** | `int` | Maximum token size for LLM generation (affects entity relation summaries) | `32768` | +| **llm\_model\_max\_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `16` | +| **llm\_model\_kwargs** | `dict` | Additional parameters for 
LLM generation | | +| **vector\_db\_storage\_cls\_kwargs** | `dict` | Additional parameters for vector database (currently not used) | | +| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` | +| **enable\_llm\_cache\_for\_entity\_extract** | `bool` | If `TRUE`, stores LLM results in cache for entity extraction; Good for beginners to debug your application | `FALSE` | +| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` | +| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | +| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | ### Error Handling
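For readers of this diff, here is a minimal configuration sketch tying together the pieces introduced above: the PostgreSQL backends from `lightrag/kg/postgres_impl.py` and the new `enable_llm_cache_for_entity_extract` flag. It is an illustration only, assuming the PostgreSQL storage classes are selected by name the same way the Oracle/Neo4J backends are; the `PGVectorStorage` name, the stub model functions, and the database connection wiring are assumptions, so treat `examples/lightrag_zhipu_postgres_demo.py` as the authoritative setup.

```python
# Hypothetical sketch, not copied from the demo script. PGKVStorage and PGGraphStorage
# are defined in lightrag/kg/postgres_impl.py in this PR; the "PGVectorStorage" name,
# the stub model functions, and the connection handling are assumptions.
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc


async def my_llm_model_func(prompt: str, **kwargs) -> str:
    ...  # placeholder: plug in your async LLM call (the demo uses zhipu glm-4-flashx)


async def my_embedding_func(texts: list[str]):
    ...  # placeholder: plug in your async embedding call (the demo uses 768-dim embeddings)


rag = LightRAG(
    working_dir="./lightrag_pg_cache",
    kv_storage="PGKVStorage",          # docs, chunks, and the (workspace, mode, id) LLM cache
    vector_storage="PGVectorStorage",  # assumed name; pgvector-backed
    graph_storage="PGGraphStorage",    # Apache AGE-backed
    llm_model_func=my_llm_model_func,
    llm_model_max_token_size=32768,
    enable_llm_cache_for_entity_extract=True,  # new flag introduced by this PR
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=my_embedding_func,
    ),
)
```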
@@ -1206,6 +1213,7 @@ curl "http://localhost:9621/health" ``` ## Development +Contribute to the project: [Guide](contributor-readme.MD) ### Running in Development Mode diff --git a/contributor-readme.MD b/contributor-readme.MD new file mode 100644 index 00000000..2168d469 --- /dev/null +++ b/contributor-readme.MD @@ -0,0 +1,12 @@ +# Handy Tips for Developers Who Want to Contribute to the Project +## Pre-commit Hooks +Please ensure you have run pre-commit hooks before committing your changes. +### Guides +1. **Installing Pre-commit Hooks**: + - Install pre-commit using pip: `pip install pre-commit` + - Initialize pre-commit in your repository: `pre-commit install` + - Run pre-commit hooks: `pre-commit run --all-files` + +2. **Pre-commit Hooks Configuration**: + - Create a `.pre-commit-config.yaml` file in the root of your repository. + - Add your hooks to the `.pre-commit-config.yaml`file. diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py index f28480df..d0461d84 100644 --- a/examples/lightrag_zhipu_postgres_demo.py +++ b/examples/lightrag_zhipu_postgres_demo.py @@ -43,6 +43,7 @@ async def main(): llm_model_name="glm-4-flashx", llm_model_max_async=4, llm_model_max_token_size=32768, + enable_llm_cache_for_entity_extract=True, embedding_func=EmbeddingFunc( embedding_dim=768, max_token_size=8192, diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 704fa476..033d63d6 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -151,7 +151,10 @@ class PostgreSQLDB: try: await conn.execute('SET search_path = ag_catalog, "$user", public') await conn.execute(f"""select create_graph('{graph_name}')""") - except asyncpg.exceptions.InvalidSchemaNameError: + except ( + asyncpg.exceptions.InvalidSchemaNameError, + asyncpg.exceptions.UniqueViolationError, + ): pass @@ -160,7 +163,6 @@ class PGKVStorage(BaseKVStorage): db: PostgreSQLDB = None def __post_init__(self): - self._data = {} self._max_batch_size = self.global_config["embedding_batch_num"] ################ QUERY METHODS ################ @@ -181,6 +183,19 @@ class PGKVStorage(BaseKVStorage): else: return None + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + """Specifically for llm_response_cache.""" + sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + params = {"workspace": self.db.workspace, mode: mode, "id": id} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(sql, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + return res + else: + return None + # Query by id async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]: """Get doc_chunks data by id""" @@ -229,33 +244,30 @@ class PGKVStorage(BaseKVStorage): ################ INSERT METHODS ################ async def upsert(self, data: Dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) if self.namespace == "text_chunks": pass elif self.namespace == "full_docs": - for k, v in self._data.items(): + for k, v in data.items(): upsert_sql = SQL_TEMPLATES["upsert_doc_full"] - data = { + _data = { "id": k, "content": v["content"], "workspace": self.db.workspace, } - await self.db.execute(upsert_sql, data) + await self.db.execute(upsert_sql, _data) elif self.namespace == "llm_response_cache": - for mode, items in self._data.items(): + for mode, items in data.items(): for k, v in items.items(): upsert_sql = 
SQL_TEMPLATES["upsert_llm_response_cache"] - data = { + _data = { "workspace": self.db.workspace, "id": k, "original_prompt": v["original_prompt"], - "return": v["return"], + "return_value": v["return"], "mode": mode, } - await self.db.execute(upsert_sql, data) - return left_data + await self.db.execute(upsert_sql, _data) async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: @@ -977,9 +989,6 @@ class PGGraphStorage(BaseGraphStorage): source_node_label = source_node_id.strip('"') target_node_label = target_node_id.strip('"') edge_properties = edge_data - logger.info( - f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}" - ) query = """MATCH (source:`{src_label}`) WITH source @@ -1028,8 +1037,8 @@ TABLES = { doc_name VARCHAR(1024), content TEXT, meta JSONB, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) )""" }, @@ -1042,8 +1051,8 @@ TABLES = { tokens INTEGER, content TEXT, content_vector VECTOR, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) )""" }, @@ -1054,8 +1063,8 @@ TABLES = { entity_name VARCHAR(255), content TEXT, content_vector VECTOR, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) )""" }, @@ -1067,8 +1076,8 @@ TABLES = { target_id VARCHAR(256), content TEXT, content_vector VECTOR, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) )""" }, @@ -1078,10 +1087,10 @@ TABLES = { id varchar(255) NOT NULL, mode varchar(32) NOT NULL, original_prompt TEXT, - return TEXT, - createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP, - CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id) + return_value TEXT, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) )""" }, "LIGHTRAG_DOC_STATUS": { @@ -1109,9 +1118,12 @@ SQL_TEMPLATES = { chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode + "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 """, + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 + """, "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) """, @@ -1119,22 +1131,22 @@ SQL_TEMPLATES = { chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode FROM 
LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) """, "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) VALUES ($1, $2, $3) ON CONFLICT (workspace,id) DO UPDATE - SET content = $2, updatetime = CURRENT_TIMESTAMP + SET content = $2, update_time = CURRENT_TIMESTAMP """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,"return",mode) + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (workspace,id) DO UPDATE + ON CONFLICT (workspace,mode,id) DO UPDATE SET original_prompt = EXCLUDED.original_prompt, - "return"=EXCLUDED."return", + return_value=EXCLUDED.return_value, mode=EXCLUDED.mode, - updatetime = CURRENT_TIMESTAMP + update_time = CURRENT_TIMESTAMP """, "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, chunk_order_index, full_doc_id, content, content_vector) @@ -1145,7 +1157,7 @@ SQL_TEMPLATES = { full_doc_id=EXCLUDED.full_doc_id, content = EXCLUDED.content, content_vector=EXCLUDED.content_vector, - updatetime = CURRENT_TIMESTAMP + update_time = CURRENT_TIMESTAMP """, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector) VALUES ($1, $2, $3, $4, $5) @@ -1153,7 +1165,7 @@ SQL_TEMPLATES = { SET entity_name=EXCLUDED.entity_name, content=EXCLUDED.content, content_vector=EXCLUDED.content_vector, - updatetime=CURRENT_TIMESTAMP + update_time=CURRENT_TIMESTAMP """, "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, target_id, content, content_vector) @@ -1162,7 +1174,7 @@ SQL_TEMPLATES = { SET source_id=EXCLUDED.source_id, target_id=EXCLUDED.target_id, content=EXCLUDED.content, - content_vector=EXCLUDED.content_vector, updatetime = CURRENT_TIMESTAMP + content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP """, # SQL for VectorStorage "entities": """SELECT entity_name FROM diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index fc71508c..05de8d9f 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -176,6 +176,8 @@ class LightRAG: vector_db_storage_cls_kwargs: dict = field(default_factory=dict) enable_llm_cache: bool = True + # Sometimes there are some reason the LLM failed at Extracting Entities, and we want to continue without LLM cost, we can use this flag + enable_llm_cache_for_entity_extract: bool = False # extension addon_params: dict = field(default_factory=dict) @@ -402,6 +404,7 @@ class LightRAG: knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, + llm_response_cache=self.llm_response_cache, global_config=asdict(self), ) diff --git a/lightrag/operate.py b/lightrag/operate.py index f21e41ff..b2c4d215 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -253,9 +253,13 @@ async def extract_entities( entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, global_config: dict, + llm_response_cache: BaseKVStorage = None, ) -> Union[BaseGraphStorage, None]: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] + enable_llm_cache_for_entity_extract: bool = global_config[ + "enable_llm_cache_for_entity_extract" + ] ordered_chunks = list(chunks.items()) # add language and example number params to prompt @@ -300,6 
+304,52 @@ async def extract_entities( already_entities = 0 already_relations = 0 + async def _user_llm_func_with_cache( + input_text: str, history_messages: list[dict[str, str]] = None + ) -> str: + if enable_llm_cache_for_entity_extract and llm_response_cache: + need_to_restore = False + if ( + global_config["embedding_cache_config"] + and global_config["embedding_cache_config"]["enabled"] + ): + new_config = global_config.copy() + new_config["embedding_cache_config"] = None + new_config["enable_llm_cache"] = True + llm_response_cache.global_config = new_config + need_to_restore = True + if history_messages: + history = json.dumps(history_messages) + _prompt = history + "\n" + input_text + else: + _prompt = input_text + + arg_hash = compute_args_hash(_prompt) + cached_return, _1, _2, _3 = await handle_cache( + llm_response_cache, arg_hash, _prompt, "default" + ) + if need_to_restore: + llm_response_cache.global_config = global_config + if cached_return: + return cached_return + + if history_messages: + res: str = await use_llm_func( + input_text, history_messages=history_messages + ) + else: + res: str = await use_llm_func(input_text) + await save_to_cache( + llm_response_cache, + CacheData(args_hash=arg_hash, content=res, prompt=_prompt), + ) + return res + + if history_messages: + return await use_llm_func(input_text, history_messages=history_messages) + else: + return await use_llm_func(input_text) + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] @@ -310,17 +360,19 @@ async def extract_entities( **context_base, input_text="{input_text}" ).format(**context_base, input_text=content) - final_result = await use_llm_func(hint_prompt) + final_result = await _user_llm_func_with_cache(hint_prompt) history = pack_user_ass_to_openai_messages(hint_prompt, final_result) for now_glean_index in range(entity_extract_max_gleaning): - glean_result = await use_llm_func(continue_prompt, history_messages=history) + glean_result = await _user_llm_func_with_cache( + continue_prompt, history_messages=history + ) history += pack_user_ass_to_openai_messages(continue_prompt, glean_result) final_result += glean_result if now_glean_index == entity_extract_max_gleaning - 1: break - if_loop_result: str = await use_llm_func( + if_loop_result: str = await _user_llm_func_with_cache( if_loop_prompt, history_messages=history ) if_loop_result = if_loop_result.strip().strip('"').strip("'").lower() diff --git a/lightrag/utils.py b/lightrag/utils.py index b7c9649a..1f6bf405 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -454,7 +454,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): # For naive mode, only use simple cache matching if mode == "naive": - mode_cache = await hashing_kv.get_by_id(mode) or {} + if exists_func(hashing_kv, "get_by_mode_and_id"): + mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} + else: + mode_cache = await hashing_kv.get_by_id(mode) or {} if args_hash in mode_cache: return mode_cache[args_hash]["return"], None, None, None return None, None, None, None @@ -488,7 +491,10 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return best_cached_response, None, None, None else: # Use regular cache - mode_cache = await hashing_kv.get_by_id(mode) or {} + if exists_func(hashing_kv, "get_by_mode_and_id"): + mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} + else: + mode_cache = await 
hashing_kv.get_by_id(mode) or {} if args_hash in mode_cache: return mode_cache[args_hash]["return"], None, None, None @@ -510,7 +516,13 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): if hashing_kv is None or hasattr(cache_data.content, "__aiter__"): return - mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + if exists_func(hashing_kv, "get_by_mode_and_id"): + mode_cache = ( + await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash) + or {} + ) + else: + mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} mode_cache[cache_data.args_hash] = { "return": cache_data.content, @@ -543,3 +555,15 @@ def safe_unicode_decode(content): ) return decoded_content + + +def exists_func(obj, func_name: str) -> bool: + """Check if a function exists in an object or not. + :param obj: + :param func_name: + :return: True / False + """ + if callable(getattr(obj, func_name, None)): + return True + else: + return False
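A closing note on the caching changes in `lightrag/utils.py`: `handle_cache` and `save_to_cache` now probe the KV backend for an optional `get_by_mode_and_id` method via the new `exists_func` helper and fall back to the old `get_by_id(mode)` bucket when it is absent. The standalone toy below (not code from the repository; the storage classes are made up) shows that probe-and-fallback pattern in isolation:

```python
import asyncio


def exists_func(obj, func_name: str) -> bool:
    # Same capability check the PR adds to lightrag/utils.py.
    return callable(getattr(obj, func_name, None))


class LegacyKV:
    """Stands in for a backend that only buckets cache entries by mode (e.g. JsonKVStorage)."""

    async def get_by_id(self, mode: str) -> dict:
        return {"hash-abc": {"return": "cached answer"}}


class ModeAwareKV(LegacyKV):
    """Stands in for PGKVStorage, which can fetch a single (mode, id) cache row."""

    async def get_by_mode_and_id(self, mode: str, id: str) -> dict:
        return {id: {"return": "cached answer"}}


async def lookup(kv, mode: str, args_hash: str):
    # Mirrors the branching added to handle_cache(): prefer the precise (mode, id)
    # lookup when the backend offers it, otherwise load the whole mode bucket.
    if exists_func(kv, "get_by_mode_and_id"):
        mode_cache = await kv.get_by_mode_and_id(mode, args_hash) or {}
    else:
        mode_cache = await kv.get_by_id(mode) or {}
    return mode_cache.get(args_hash, {}).get("return")


print(asyncio.run(lookup(LegacyKV(), "default", "hash-abc")))     # -> cached answer
print(asyncio.run(lookup(ModeAwareKV(), "default", "hash-abc")))  # -> cached answer
```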