Fix linting

yangdx
2025-03-10 02:07:19 +08:00
parent 14e1b31d1c
commit 4065a7df92
6 changed files with 44 additions and 25 deletions

View File

@@ -359,12 +359,10 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     # Inject chunk configuration
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
 
     # Inject LLM cache configuration
     args.enable_llm_cache_for_extract = get_env_value(
-        "ENABLE_LLM_CACHE_FOR_EXTRACT",
-        False,
-        bool
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
     )
 
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
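The reflowed call above reads ENABLE_LLM_CACHE_FOR_EXTRACT through get_env_value with a default and a target type. The helper's body is not part of this diff; a minimal sketch of how such a helper could behave, assuming it special-cases booleans (a plain bool() cast would treat "false" as truthy):

import os

def get_env_value(name: str, default, value_type=str):
    """Hypothetical helper: read an environment variable and coerce it to value_type."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    if value_type is bool:
        # bool("false") is True, so booleans need explicit string matching.
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return value_type(raw)

# Mirrors the call in the hunk above.
enable_llm_cache_for_extract = get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool)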

View File

@@ -96,11 +96,15 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
-                logger.info(f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}")
+                logger.info(
+                    f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
                 await clear_all_update_flags(self.namespace)
 
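The condition wrapped across lines here distinguishes the multiprocess case, where storage_updated is a shared value read through .value, from the single-process case, where it is a plain bool. A standalone sketch of that pattern, with is_multiprocess and storage_updated as stand-ins for the shared-storage globals used by the class:

import multiprocessing

is_multiprocess = True
# Shared flag in multiprocess mode (read via .value), plain bool otherwise.
storage_updated = multiprocessing.Value("b", True) if is_multiprocess else True

def needs_flush() -> bool:
    return bool(
        (is_multiprocess and storage_updated.value)
        or (not is_multiprocess and storage_updated)
    )

if needs_flush():
    print("write in-memory data to disk, then clear the update flags")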

View File

@@ -44,21 +44,28 @@ class JsonKVStorage(BaseKVStorage):
         loaded_data = load_json(self._file_name) or {}
         async with self._storage_lock:
             self._data.update(loaded_data)
 
             # Calculate data count based on namespace
             if self.namespace.endswith("cache"):
                 # For cache namespaces, sum the cache entries across all cache types
-                data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values()
-                               if isinstance(first_level_dict, dict))
+                data_count = sum(
+                    len(first_level_dict)
+                    for first_level_dict in loaded_data.values()
+                    if isinstance(first_level_dict, dict)
+                )
             else:
                 # For non-cache namespaces, use the original count method
                 data_count = len(loaded_data)
 
-            logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records")
+            logger.info(
+                f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+            )
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
@@ -66,17 +73,21 @@ class JsonKVStorage(BaseKVStorage):
                 # Calculate data count based on namespace
                 if self.namespace.endswith("cache"):
                     # # For cache namespaces, sum the cache entries across all cache types
-                    data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
-                                   if isinstance(first_level_dict, dict))
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
                 else:
                     # For non-cache namespaces, use the original count method
                     data_count = len(data_dict)
 
-                logger.info(f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}")
+                logger.info(
+                    f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
                 await clear_all_update_flags(self.namespace)
 
     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage

View File

@@ -344,6 +344,7 @@ async def set_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = True
 
+
 async def clear_all_update_flags(namespace: str):
     """Clear all update flag of namespace indicating all workers need to reload data from files"""
     global _update_flags
@@ -360,6 +361,7 @@ async def clear_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = False
 
+
 async def get_all_update_flags_status() -> Dict[str, list]:
     """
     Get update flags status for all namespaces.
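The two hunks above only add blank lines between these top-level coroutines; for context, they iterate over per-worker flags keyed by namespace. A simplified single-process sketch of that bookkeeping (the dict layout and sample namespaces are assumed, and the multiprocess branch shown in the diff is omitted):

import asyncio
from typing import Dict, List

# Assumed layout: one boolean flag per worker, keyed by namespace.
_update_flags: Dict[str, List[bool]] = {"doc_status": [True, True], "kv_store": [False]}

async def clear_all_update_flags(namespace: str) -> None:
    # Mark every worker's flag for this namespace as handled.
    for i in range(len(_update_flags[namespace])):
        _update_flags[namespace][i] = False

async def get_all_update_flags_status() -> Dict[str, list]:
    # Snapshot of every namespace and its per-worker flags.
    return {ns: list(flags) for ns, flags in _update_flags.items()}

asyncio.run(clear_all_update_flags("doc_status"))
print(asyncio.run(get_all_update_flags_status()))
# {'doc_status': [False, False], 'kv_store': [False]}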

View File

@@ -354,7 +354,9 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
-            global_config=asdict(self),  # Add global_config to ensure cache works properly
+            global_config=asdict(
+                self
+            ),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
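The wrapped argument above hands the LightRAG instance's own dataclass fields down to the cache storage as a plain dict via dataclasses.asdict. A toy illustration of that pattern; the Engine and Storage classes here are invented for the example:

from dataclasses import asdict, dataclass, field

@dataclass
class Storage:
    namespace: str
    global_config: dict = field(default_factory=dict)

@dataclass
class Engine:
    chunk_size: int = 1200
    chunk_overlap_size: int = 100

    def build_cache_storage(self) -> Storage:
        # Pass this instance's fields down as a dict, mirroring global_config=asdict(self).
        return Storage(namespace="llm_response_cache", global_config=asdict(self))

print(Engine().build_cache_storage().global_config)
# {'chunk_size': 1200, 'chunk_overlap_size': 100}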

View File

@@ -706,7 +706,7 @@ class CacheData:
 async def save_to_cache(hashing_kv, cache_data: CacheData):
     """Save data to cache, with improved handling for streaming responses and duplicate content.
 
     Args:
         hashing_kv: The key-value storage for caching
         cache_data: The cache data to save
@@ -714,12 +714,12 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     # Skip if storage is None or content is a streaming response
     if hashing_kv is None or not cache_data.content:
         return
 
     # If content is a streaming response, don't cache it
     if hasattr(cache_data.content, "__aiter__"):
         logger.debug("Streaming response detected, skipping cache")
         return
 
     # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
@@ -728,14 +728,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         )
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
 
     # Check if we already have identical content cached
     if cache_data.args_hash in mode_cache:
         existing_content = mode_cache[cache_data.args_hash].get("return")
         if existing_content == cache_data.content:
-            logger.info(f"Cache content unchanged for {cache_data.args_hash}, skipping update")
+            logger.info(
+                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
+            )
             return
 
     # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
@@ -750,7 +752,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "embedding_max": cache_data.max_val,
         "original_prompt": cache_data.prompt,
     }
 
     # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})
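The reflowed logger.info above sits inside the check that skips the upsert when the cached "return" value already matches the new content. A condensed, self-contained sketch of that skip-if-unchanged flow, using an in-memory stand-in for hashing_kv rather than the real storage class:

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MemoryKV:
    """In-memory stand-in exposing the two methods this flow needs."""
    def __init__(self):
        self._data = {}

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, data):
        self._data.update(data)

async def save_if_changed(kv, mode: str, args_hash: str, content: str) -> None:
    mode_cache = await kv.get_by_id(mode) or {}
    if mode_cache.get(args_hash, {}).get("return") == content:
        # Identical content is already cached: skip the write entirely.
        logger.info(f"Cache content unchanged for {args_hash}, skipping update")
        return
    mode_cache[args_hash] = {"return": content}
    await kv.upsert({mode: mode_cache})

async def main():
    kv = MemoryKV()
    await save_if_changed(kv, "default", "hash-a", "answer")
    await save_if_changed(kv, "default", "hash-a", "answer")  # hits the skip path

asyncio.run(main())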