From b0d87b2e296e30f5705db497fc9204c3f33cdab5 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 31 Jan 2025 15:33:50 +0800
Subject: [PATCH] Fix linting

---
 lightrag/lightrag.py | 12 ++++++------
 lightrag/operate.py  | 22 +++++++++++++---------
 lightrag/utils.py    |  8 ++++++--
 3 files changed, 25 insertions(+), 17 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index ca82a3d7..f0fb92fd 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -197,12 +197,12 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         # Init embedding functions with separate instances for insert and query
-        self.insert_embedding_func = limit_async_func_call(self.embedding_func_max_async)(
-            self.embedding_func
-        )
-        self.query_embedding_func = limit_async_func_call(self.embedding_func_max_async_query)(
-            self.embedding_func
-        )
+        self.insert_embedding_func = limit_async_func_call(
+            self.embedding_func_max_async
+        )(self.embedding_func)
+        self.query_embedding_func = limit_async_func_call(
+            self.embedding_func_max_async_query
+        )(self.embedding_func)
 
         # Initialize all storages
         self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 481a31bb..d88dc7c2 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -352,7 +352,7 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            custom_llm = None 
+            custom_llm = None
             if (
                 global_config["embedding_cache_config"]
                 and global_config["embedding_cache_config"]["enabled"]
@@ -360,10 +360,14 @@ async def extract_entities(
                 new_config = global_config.copy()
                 new_config["embedding_cache_config"] = None
                 new_config["enable_llm_cache"] = True
-
+
                 # create a llm function with new_config for handle_cache
                 async def custom_llm(
-                    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+                    prompt,
+                    system_prompt=None,
+                    history_messages=[],
+                    keyword_extraction=False,
+                    **kwargs,
                 ) -> str:
                     # Merge new_config with the other kwargs so the other parameters are not overwritten
                     merged_config = {**kwargs, **new_config}
@@ -374,7 +378,7 @@ async def extract_entities(
                     keyword_extraction=keyword_extraction,
                     **merged_config,
                 )
-
+
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text
@@ -383,12 +387,12 @@ async def extract_entities(
 
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache, 
-                arg_hash, 
-                _prompt, 
-                "default", 
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
                 cache_type="default",
-                llm=custom_llm
+                llm=custom_llm,
             )
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
diff --git a/lightrag/utils.py b/lightrag/utils.py
index daab10b0..e5b3b8d8 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -491,7 +491,9 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)
 
 
-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None):
+async def handle_cache(
+    hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None
+):
     """Generic cache handling function"""
     if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
         return None, None, None, None
@@ -528,7 +530,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
             similarity_threshold=embedding_cache_config["similarity_threshold"],
             mode=mode,
             use_llm_check=use_llm_check,
-            llm_func=llm if (use_llm_check and llm is not None) else (llm_model_func if use_llm_check else None),
+            llm_func=llm
+            if (use_llm_check and llm is not None)
+            else (llm_model_func if use_llm_check else None),
             original_prompt=prompt if use_llm_check else None,
             cache_type=cache_type,
         )
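
For context on the lightrag.py hunk: limit_async_func_call is a decorator
factory that caps how many calls to the wrapped async function can be in
flight at once, which is why insert and query get separate wrapper instances
with independent concurrency budgets. Below is a minimal sketch of that
pattern, assuming a plain asyncio.Semaphore; the names and details are
illustrative, not the exact lightrag/utils.py implementation:

    import asyncio
    from functools import wraps

    def limit_async_func_call(max_size: int):
        """Decorator factory: allow at most max_size concurrent calls."""
        def decorator(func):
            sem = asyncio.Semaphore(max_size)

            @wraps(func)
            async def wrapper(*args, **kwargs):
                async with sem:  # waits while max_size calls are in flight
                    return await func(*args, **kwargs)

            return wrapper
        return decorator

    async def embedding_func(texts: list[str]) -> list[list[float]]:
        return [[0.0] * 8 for _ in texts]  # stand-in embedding model

    # Separate instances, mirroring LightRAG.__init__ above: insert-time and
    # query-time embedding calls are throttled independently.
    insert_embedding_func = limit_async_func_call(16)(embedding_func)
    query_embedding_func = limit_async_func_call(4)(embedding_func)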
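
The custom_llm closure in the operate.py hunk relies on Python's dict-merge
precedence: in {**kwargs, **new_config} the right-hand mapping wins on key
clashes, so embedding_cache_config=None and enable_llm_cache=True always take
effect while unrelated caller kwargs pass through untouched. A small
self-contained demonstration (values made up for illustration):

    kwargs = {"enable_llm_cache": False, "temperature": 0.2}  # caller-supplied
    new_config = {"embedding_cache_config": None, "enable_llm_cache": True}

    merged_config = {**kwargs, **new_config}  # later unpacking overrides earlier
    assert merged_config["enable_llm_cache"] is True    # forced by new_config
    assert merged_config["embedding_cache_config"] is None
    assert merged_config["temperature"] == 0.2          # caller extras survive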