From 4065a7df92cbe388741b463c4e48d3863920ef87 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 02:07:19 +0800
Subject: [PATCH] Fix linting

---
 lightrag/api/utils_api.py           |  6 ++----
 lightrag/kg/json_doc_status_impl.py |  8 ++++++--
 lightrag/kg/json_kv_impl.py         | 33 ++++++++++++++++++-----------
 lightrag/kg/shared_storage.py       |  2 ++
 lightrag/lightrag.py                |  4 +++-
 lightrag/utils.py                   | 16 +++++++++------
 6 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
index 9a619f9e..ffe63abd 100644
--- a/lightrag/api/utils_api.py
+++ b/lightrag/api/utils_api.py
@@ -359,12 +359,10 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     # Inject chunk configuration
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
-
+
     # Inject LLM cache configuration
     args.enable_llm_cache_for_extract = get_env_value(
-        "ENABLE_LLM_CACHE_FOR_EXTRACT",
-        False,
-        bool
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
     )
 
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 5b378c17..4502397b 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -96,11 +96,15 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
-                logger.info(f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}")
+                logger.info(
+                    f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
                 await clear_all_update_flags(self.namespace)
 
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 6c855a25..80abe92e 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -44,21 +44,28 @@ class JsonKVStorage(BaseKVStorage):
             loaded_data = load_json(self._file_name) or {}
             async with self._storage_lock:
                 self._data.update(loaded_data)
-
+
                 # Calculate data count based on namespace
                 if self.namespace.endswith("cache"):
                     # For cache namespaces, sum the cache entries across all cache types
-                    data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values()
-                                     if isinstance(first_level_dict, dict))
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in loaded_data.values()
+                        if isinstance(first_level_dict, dict)
+                    )
                 else:
                     # For non-cache namespaces, use the original count method
                     data_count = len(loaded_data)
-
-                logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records")
+
+                logger.info(
+                    f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+                )
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
@@ -66,17 +73,21 @@ class JsonKVStorage(BaseKVStorage):
                 # Calculate data count based on namespace
                 if self.namespace.endswith("cache"):
                     # For cache namespaces, sum the cache entries across all cache types
-                    data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
-                                     if isinstance(first_level_dict, dict))
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
                 else:
                     # For non-cache namespaces, use the original count method
                     data_count = len(data_dict)
-
-                logger.info(f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}")
+
+                logger.info(
+                    f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
                 await clear_all_update_flags(self.namespace)
 
-
     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
 
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 9ce04d23..63ff1f0d 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -344,6 +344,7 @@ async def set_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = True
 
+
 async def clear_all_update_flags(namespace: str):
     """Clear all update flag of namespace indicating all workers need to reload data from files"""
     global _update_flags
@@ -360,6 +361,7 @@ async def clear_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = False
 
+
 async def get_all_update_flags_status() -> Dict[str, list]:
     """
     Get update flags status for all namespaces.
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index a91aa6fa..ceb47a01 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -354,7 +354,9 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
-            global_config=asdict(self),  # Add global_config to ensure cache works properly
+            global_config=asdict(
+                self
+            ),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 56548420..e8f79610 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -706,7 +706,7 @@ class CacheData:
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
     """Save data to cache, with improved handling for streaming responses and duplicate content.
-
+
     Args:
         hashing_kv: The key-value storage for caching
         cache_data: The cache data to save
@@ -714,12 +714,12 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     # Skip if storage is None or content is a streaming response
    if hashing_kv is None or not cache_data.content:
         return
-
+
     # If content is a streaming response, don't cache it
     if hasattr(cache_data.content, "__aiter__"):
         logger.debug("Streaming response detected, skipping cache")
         return
-
+
     # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
@@ -728,14 +728,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         )
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
-
+
     # Check if we already have identical content cached
     if cache_data.args_hash in mode_cache:
         existing_content = mode_cache[cache_data.args_hash].get("return")
         if existing_content == cache_data.content:
-            logger.info(f"Cache content unchanged for {cache_data.args_hash}, skipping update")
+            logger.info(
+                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
+            )
             return
-
+
     # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
@@ -750,7 +752,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "embedding_max": cache_data.max_val,
         "original_prompt": cache_data.prompt,
    }
-
+
     # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})
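
Two pieces of logic that this patch reflows are easier to read outside diff context. First, the json_kv_impl.py hunks reformat a record-counting rule: cache namespaces hold a two-level dict of {mode: {args_hash: entry}}, so records are counted one level down, while every other namespace is a flat dict counted by its top-level keys. A minimal standalone sketch of that rule follows; the helper name count_records is chosen here for illustration and is not a function introduced by the patch:

    from typing import Any

    def count_records(namespace: str, data: dict[str, Any]) -> int:
        """Count stored records the way the patched load/flush logging does."""
        if namespace.endswith("cache"):
            # Cache namespaces: sum the second-level entries across all cache
            # types, defensively skipping any non-dict values.
            return sum(
                len(first_level_dict)
                for first_level_dict in data.values()
                if isinstance(first_level_dict, dict)
            )
        # Non-cache namespaces: one record per top-level key.
        return len(data)

    # Example: three cache entries across two modes, two plain records.
    assert count_records("llm_response_cache", {"default": {"a": 1, "b": 2}, "mix": {"c": 3}}) == 3
    assert count_records("full_docs", {"doc-1": {}, "doc-2": {}}) == 2

Second, the index_done_callback hunks wrap a dirty-flag check whose shape depends on process mode: with multiple workers the flag is a shared multiprocessing value read through its .value attribute, while in single-process mode it is a plain bool. A hedged sketch of the same predicate, with needs_flush again an illustrative name rather than part of the patch:

    def needs_flush(is_multiprocess: bool, storage_updated) -> bool:
        # Shared multiprocessing value under multiple workers; plain bool otherwise.
        return bool(storage_updated.value) if is_multiprocess else bool(storage_updated)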