Fix linting
This commit is contained in:
@@ -359,12 +359,10 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
|
|||||||
# Inject chunk configuration
|
# Inject chunk configuration
|
||||||
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
|
||||||
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
|
||||||
|
|
||||||
# Inject LLM cache configuration
|
# Inject LLM cache configuration
|
||||||
args.enable_llm_cache_for_extract = get_env_value(
|
args.enable_llm_cache_for_extract = get_env_value(
|
||||||
"ENABLE_LLM_CACHE_FOR_EXTRACT",
|
"ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
|
||||||
False,
|
|
||||||
bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
|
||||||
|
@@ -96,11 +96,15 @@ class JsonDocStatusStorage(DocStatusStorage):
|
|||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
|
if (is_multiprocess and self.storage_updated.value) or (
|
||||||
|
not is_multiprocess and self.storage_updated
|
||||||
|
):
|
||||||
data_dict = (
|
data_dict = (
|
||||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||||
)
|
)
|
||||||
logger.info(f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}")
|
logger.info(
|
||||||
|
f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}"
|
||||||
|
)
|
||||||
write_json(data_dict, self._file_name)
|
write_json(data_dict, self._file_name)
|
||||||
await clear_all_update_flags(self.namespace)
|
await clear_all_update_flags(self.namespace)
|
||||||
|
|
||||||
|
@@ -44,21 +44,28 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
loaded_data = load_json(self._file_name) or {}
|
loaded_data = load_json(self._file_name) or {}
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
self._data.update(loaded_data)
|
self._data.update(loaded_data)
|
||||||
|
|
||||||
# Calculate data count based on namespace
|
# Calculate data count based on namespace
|
||||||
if self.namespace.endswith("cache"):
|
if self.namespace.endswith("cache"):
|
||||||
# For cache namespaces, sum the cache entries across all cache types
|
# For cache namespaces, sum the cache entries across all cache types
|
||||||
data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values()
|
data_count = sum(
|
||||||
if isinstance(first_level_dict, dict))
|
len(first_level_dict)
|
||||||
|
for first_level_dict in loaded_data.values()
|
||||||
|
if isinstance(first_level_dict, dict)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# For non-cache namespaces, use the original count method
|
# For non-cache namespaces, use the original count method
|
||||||
data_count = len(loaded_data)
|
data_count = len(loaded_data)
|
||||||
|
|
||||||
logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records")
|
logger.info(
|
||||||
|
f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
|
||||||
|
)
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
|
if (is_multiprocess and self.storage_updated.value) or (
|
||||||
|
not is_multiprocess and self.storage_updated
|
||||||
|
):
|
||||||
data_dict = (
|
data_dict = (
|
||||||
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
dict(self._data) if hasattr(self._data, "_getvalue") else self._data
|
||||||
)
|
)
|
||||||
@@ -66,17 +73,21 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
# Calculate data count based on namespace
|
# Calculate data count based on namespace
|
||||||
if self.namespace.endswith("cache"):
|
if self.namespace.endswith("cache"):
|
||||||
# # For cache namespaces, sum the cache entries across all cache types
|
# # For cache namespaces, sum the cache entries across all cache types
|
||||||
data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
|
data_count = sum(
|
||||||
if isinstance(first_level_dict, dict))
|
len(first_level_dict)
|
||||||
|
for first_level_dict in data_dict.values()
|
||||||
|
if isinstance(first_level_dict, dict)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# For non-cache namespaces, use the original count method
|
# For non-cache namespaces, use the original count method
|
||||||
data_count = len(data_dict)
|
data_count = len(data_dict)
|
||||||
|
|
||||||
logger.info(f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}")
|
logger.info(
|
||||||
|
f"Process {os.getpid()} KV writting {data_count} records to {self.namespace}"
|
||||||
|
)
|
||||||
write_json(data_dict, self._file_name)
|
write_json(data_dict, self._file_name)
|
||||||
await clear_all_update_flags(self.namespace)
|
await clear_all_update_flags(self.namespace)
|
||||||
|
|
||||||
|
|
||||||
async def get_all(self) -> dict[str, Any]:
|
async def get_all(self) -> dict[str, Any]:
|
||||||
"""Get all data from storage
|
"""Get all data from storage
|
||||||
|
|
||||||
|
@@ -344,6 +344,7 @@ async def set_all_update_flags(namespace: str):
|
|||||||
else:
|
else:
|
||||||
_update_flags[namespace][i] = True
|
_update_flags[namespace][i] = True
|
||||||
|
|
||||||
|
|
||||||
async def clear_all_update_flags(namespace: str):
|
async def clear_all_update_flags(namespace: str):
|
||||||
"""Clear all update flag of namespace indicating all workers need to reload data from files"""
|
"""Clear all update flag of namespace indicating all workers need to reload data from files"""
|
||||||
global _update_flags
|
global _update_flags
|
||||||
@@ -360,6 +361,7 @@ async def clear_all_update_flags(namespace: str):
|
|||||||
else:
|
else:
|
||||||
_update_flags[namespace][i] = False
|
_update_flags[namespace][i] = False
|
||||||
|
|
||||||
|
|
||||||
async def get_all_update_flags_status() -> Dict[str, list]:
|
async def get_all_update_flags_status() -> Dict[str, list]:
|
||||||
"""
|
"""
|
||||||
Get update flags status for all namespaces.
|
Get update flags status for all namespaces.
|
||||||
|
@@ -354,7 +354,9 @@ class LightRAG:
|
|||||||
namespace=make_namespace(
|
namespace=make_namespace(
|
||||||
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
|
||||||
),
|
),
|
||||||
global_config=asdict(self), # Add global_config to ensure cache works properly
|
global_config=asdict(
|
||||||
|
self
|
||||||
|
), # Add global_config to ensure cache works properly
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -706,7 +706,7 @@ class CacheData:
|
|||||||
|
|
||||||
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
async def save_to_cache(hashing_kv, cache_data: CacheData):
|
||||||
"""Save data to cache, with improved handling for streaming responses and duplicate content.
|
"""Save data to cache, with improved handling for streaming responses and duplicate content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hashing_kv: The key-value storage for caching
|
hashing_kv: The key-value storage for caching
|
||||||
cache_data: The cache data to save
|
cache_data: The cache data to save
|
||||||
@@ -714,12 +714,12 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
|||||||
# Skip if storage is None or content is a streaming response
|
# Skip if storage is None or content is a streaming response
|
||||||
if hashing_kv is None or not cache_data.content:
|
if hashing_kv is None or not cache_data.content:
|
||||||
return
|
return
|
||||||
|
|
||||||
# If content is a streaming response, don't cache it
|
# If content is a streaming response, don't cache it
|
||||||
if hasattr(cache_data.content, "__aiter__"):
|
if hasattr(cache_data.content, "__aiter__"):
|
||||||
logger.debug("Streaming response detected, skipping cache")
|
logger.debug("Streaming response detected, skipping cache")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get existing cache data
|
# Get existing cache data
|
||||||
if exists_func(hashing_kv, "get_by_mode_and_id"):
|
if exists_func(hashing_kv, "get_by_mode_and_id"):
|
||||||
mode_cache = (
|
mode_cache = (
|
||||||
@@ -728,14 +728,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
|
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
|
||||||
|
|
||||||
# Check if we already have identical content cached
|
# Check if we already have identical content cached
|
||||||
if cache_data.args_hash in mode_cache:
|
if cache_data.args_hash in mode_cache:
|
||||||
existing_content = mode_cache[cache_data.args_hash].get("return")
|
existing_content = mode_cache[cache_data.args_hash].get("return")
|
||||||
if existing_content == cache_data.content:
|
if existing_content == cache_data.content:
|
||||||
logger.info(f"Cache content unchanged for {cache_data.args_hash}, skipping update")
|
logger.info(
|
||||||
|
f"Cache content unchanged for {cache_data.args_hash}, skipping update"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Update cache with new content
|
# Update cache with new content
|
||||||
mode_cache[cache_data.args_hash] = {
|
mode_cache[cache_data.args_hash] = {
|
||||||
"return": cache_data.content,
|
"return": cache_data.content,
|
||||||
@@ -750,7 +752,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
|
|||||||
"embedding_max": cache_data.max_val,
|
"embedding_max": cache_data.max_val,
|
||||||
"original_prompt": cache_data.prompt,
|
"original_prompt": cache_data.prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Only upsert if there's actual new content
|
# Only upsert if there's actual new content
|
||||||
await hashing_kv.upsert({cache_data.mode: mode_cache})
|
await hashing_kv.upsert({cache_data.mode: mode_cache})
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user