refactor: migrate synchronous locks to async locks for improved concurrency
• Add UnifiedLock wrapper class • Convert with blocks to async with
This commit is contained in:
@@ -3,7 +3,7 @@ import sys
|
||||
import asyncio
|
||||
from multiprocessing.synchronize import Lock as ProcessLock
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union, TypeVar, Generic
|
||||
|
||||
|
||||
# Define a direct print function for critical logs that must be visible in all processes
|
||||
@@ -15,6 +15,43 @@ def direct_log(message, level="INFO"):
|
||||
print(f"{level}: {message}", file=sys.stderr, flush=True)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class UnifiedLock(Generic[T]):
|
||||
"""统一的锁包装类,提供同步和异步的统一接口"""
|
||||
def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
|
||||
self._lock = lock
|
||||
self._is_async = is_async
|
||||
|
||||
async def __aenter__(self) -> 'UnifiedLock[T]':
|
||||
"""异步上下文管理器入口"""
|
||||
if self._is_async:
|
||||
await self._lock.acquire()
|
||||
else:
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
if self._is_async:
|
||||
self._lock.release()
|
||||
else:
|
||||
self._lock.release()
|
||||
|
||||
def __enter__(self) -> 'UnifiedLock[T]':
|
||||
"""同步上下文管理器入口(仅用于向后兼容)"""
|
||||
if self._is_async:
|
||||
raise RuntimeError("Use 'async with' for asyncio.Lock")
|
||||
self._lock.acquire()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""同步上下文管理器出口(仅用于向后兼容)"""
|
||||
if self._is_async:
|
||||
raise RuntimeError("Use 'async with' for asyncio.Lock")
|
||||
self._lock.release()
|
||||
|
||||
|
||||
LockType = Union[ProcessLock, asyncio.Lock]
|
||||
|
||||
is_multiprocess = None
|
||||
@@ -117,26 +154,21 @@ async def get_update_flags(namespace: str):
|
||||
if _update_flags is None:
|
||||
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
||||
|
||||
if is_multiprocess:
|
||||
with _global_lock:
|
||||
if namespace not in _update_flags:
|
||||
if _manager is not None:
|
||||
_update_flags[namespace] = _manager.list()
|
||||
direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
|
||||
|
||||
if _manager is not None:
|
||||
new_update_flag = _manager.Value('b', False)
|
||||
_update_flags[namespace].append(new_update_flag)
|
||||
return new_update_flag
|
||||
else:
|
||||
async with _global_lock:
|
||||
if namespace not in _update_flags:
|
||||
async with get_storage_lock():
|
||||
if namespace not in _update_flags:
|
||||
if is_multiprocess and _manager is not None:
|
||||
_update_flags[namespace] = _manager.list()
|
||||
else:
|
||||
_update_flags[namespace] = []
|
||||
direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
|
||||
|
||||
direct_log(f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]")
|
||||
|
||||
if is_multiprocess and _manager is not None:
|
||||
new_update_flag = _manager.Value('b', False)
|
||||
else:
|
||||
new_update_flag = False
|
||||
_update_flags[namespace].append(new_update_flag)
|
||||
return new_update_flag
|
||||
|
||||
_update_flags[namespace].append(new_update_flag)
|
||||
return new_update_flag
|
||||
|
||||
async def set_update_flag(namespace: str):
|
||||
"""Set all update flag of namespace to indicate storage needs updating"""
|
||||
@@ -144,19 +176,14 @@ async def set_update_flag(namespace: str):
|
||||
if _update_flags is None:
|
||||
raise ValueError("Try to create namespace before Shared-Data is initialized")
|
||||
|
||||
if is_multiprocess:
|
||||
with _global_lock:
|
||||
if namespace not in _update_flags:
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for multiprocess mode
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
async with get_storage_lock():
|
||||
if namespace not in _update_flags:
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for both modes
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
if is_multiprocess:
|
||||
_update_flags[namespace][i].value = True
|
||||
else:
|
||||
async with _global_lock:
|
||||
if namespace not in _update_flags:
|
||||
raise ValueError(f"Namespace {namespace} not found in update flags")
|
||||
# Update flags for single process mode
|
||||
for i in range(len(_update_flags[namespace])):
|
||||
else:
|
||||
_update_flags[namespace][i] = True
|
||||
|
||||
|
||||
@@ -182,9 +209,12 @@ def try_initialize_namespace(namespace: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_storage_lock() -> LockType:
|
||||
"""return storage lock for data consistency"""
|
||||
return _global_lock
|
||||
def get_storage_lock() -> UnifiedLock:
|
||||
"""return unified storage lock for data consistency"""
|
||||
return UnifiedLock(
|
||||
lock=_global_lock,
|
||||
is_async=not is_multiprocess
|
||||
)
|
||||
|
||||
|
||||
async def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
||||
@@ -196,14 +226,11 @@ async def get_namespace_data(namespace: str) -> Dict[str, Any]:
|
||||
)
|
||||
raise ValueError("Shared dictionaries not initialized")
|
||||
|
||||
if is_multiprocess:
|
||||
with _global_lock:
|
||||
if namespace not in _shared_dicts:
|
||||
if _manager is not None:
|
||||
_shared_dicts[namespace] = _manager.dict()
|
||||
else:
|
||||
async with _global_lock:
|
||||
if namespace not in _shared_dicts:
|
||||
async with get_storage_lock():
|
||||
if namespace not in _shared_dicts:
|
||||
if is_multiprocess and _manager is not None:
|
||||
_shared_dicts[namespace] = _manager.dict()
|
||||
else:
|
||||
_shared_dicts[namespace] = {}
|
||||
|
||||
return _shared_dicts[namespace]
|
||||
|
Reference in New Issue
Block a user