From d7045121394468dfd32f58feb74efbac1c6ccb5c Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 1 Mar 2025 05:01:26 +0800 Subject: [PATCH] Refactor shared storage module to improve async handling and naming consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add async support for get_namespace_data • Rename get_update_flags to get_update_flag • Rename set_update_flag to set_all_update_flags • Update docstrings for clarity • Fix typos in log messages --- lightrag/api/routers/document_routes.py | 2 +- lightrag/kg/shared_storage.py | 20 +++++++++----------- lightrag/operate.py | 2 +- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 3fdbdf9e..ab5aff96 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -667,7 +667,7 @@ def create_document_routes( try: from lightrag.kg.shared_storage import get_namespace_data - pipeline_status = get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data("pipeline_status") # Convert to regular dict if it's a Manager.dict status_dict = dict(pipeline_status) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 5f795f0f..940d0e7b 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -18,13 +18,12 @@ def direct_log(message, level="INFO"): T = TypeVar('T') class UnifiedLock(Generic[T]): - """统一的锁包装类,提供同步和异步的统一接口""" + """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): self._lock = lock self._is_async = is_async async def __aenter__(self) -> 'UnifiedLock[T]': - """异步上下文管理器入口""" if self._is_async: await self._lock.acquire() else: @@ -32,21 +31,20 @@ class UnifiedLock(Generic[T]): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" if self._is_async: self._lock.release() else: self._lock.release() def __enter__(self) -> 'UnifiedLock[T]': - """同步上下文管理器入口(仅用于向后兼容)""" + """For backward compatibility""" if self._is_async: raise RuntimeError("Use 'async with' for asyncio.Lock") self._lock.acquire() return self def __exit__(self, exc_type, exc_val, exc_tb): - """同步上下文管理器出口(仅用于向后兼容)""" + """For backward compatibility""" if self._is_async: raise RuntimeError("Use 'async with' for asyncio.Lock") self._lock.release() @@ -153,10 +151,10 @@ async def initialize_pipeline_namespace(): direct_log(f"Process {os.getpid()} Pipeline namespace initialized") -async def get_update_flags(namespace: str): +async def get_update_flag(namespace: str): """ - Create a updated flags of a specific namespace. - Caller must store the dict object locally and use it to determine when to update the storage. + Create a namespace's update flag for a workers. + Returen the update flag to caller for referencing or reset. """ global _update_flags if _update_flags is None: @@ -178,8 +176,8 @@ async def get_update_flags(namespace: str): _update_flags[namespace].append(new_update_flag) return new_update_flag -async def set_update_flag(namespace: str): - """Set all update flag of namespace to indicate storage needs updating""" +async def set_all_update_flags(namespace: str): + """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") @@ -212,7 +210,7 @@ def try_initialize_namespace(namespace: str) -> bool: ) return True direct_log( - f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]" + f"Process {os.getpid()} storage namespace already initialized: [{namespace}]" ) return False diff --git a/lightrag/operate.py b/lightrag/operate.py index 5db5b5c6..e90854a0 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -338,7 +338,7 @@ async def extract_entities( ) -> None: from lightrag.kg.shared_storage import get_namespace_data - pipeline_status = get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data("pipeline_status") use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[