Fix race condition for health_check and ensure_workers

This commit is contained in:
yangdx
2025-04-29 00:08:52 +08:00
parent 1fc26127d5
commit 1afcbcbfb5

View File

@@ -289,9 +289,10 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
def final_decro(func):
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
tasks = set()
lock = asyncio.Lock()
initialization_lock = asyncio.Lock()
counter = 0
shutdown_event = asyncio.Event()
initialized = False # Global initialization flag
worker_health_check_task = None
# Track active future objects for cleanup
@@ -352,76 +353,62 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
while not shutdown_event.is_set():
await asyncio.sleep(5) # Check every 5 seconds
async with lock:
# Directly remove completed tasks from the tasks set
tasks.difference_update({t for t in tasks if t.done()})
# No longer acquire lock, directly operate on task set
# Use a copy of the task set to avoid concurrent modification
current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks)
# Create new workers if active tasks less than max_size for better performance
active_tasks_count = len(tasks)
workers_needed = max_size - active_tasks_count
if workers_needed > 0:
logger.info(
f"limit_async: Creating {workers_needed} new workers"
)
for _ in range(workers_needed):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
# Calculate active tasks count
active_tasks_count = len(tasks)
workers_needed = max_size - active_tasks_count
if workers_needed > 0:
logger.info(
f"limit_async: Creating {workers_needed} new workers"
)
new_tasks = set()
for _ in range(workers_needed):
task = asyncio.create_task(worker())
new_tasks.add(task)
task.add_done_callback(tasks.discard)
# Update task set in one operation
tasks.update(new_tasks)
except Exception as e:
logger.error(f"limit_async: Error in health check: {str(e)}")
finally:
logger.warning("limit_async: Health check task exiting")
# Ensure worker tasks are started
async def ensure_workers():
"""Ensure worker tasks and health check are started"""
nonlocal tasks, worker_health_check_task
"""Ensure worker threads and health check system are available
# Use timeout lock to prevent deadlock
try:
lock_acquired = False
try:
# Try to acquire the lock, wait up to 5 seconds
lock_acquired = await asyncio.wait_for(lock.acquire(), timeout=5.0)
except asyncio.TimeoutError:
logger.error(
"limit_async: Timeout acquiring lock in ensure_workers"
)
# Even if acquiring the lock times out, continue trying to create workers
This function checks if the worker system is already initialized.
If not, it performs a one-time initialization of all worker threads
and starts the health check system.
"""
nonlocal initialized, worker_health_check_task, tasks
try:
# Start the health check task (if not already started)
if (
worker_health_check_task is None
or worker_health_check_task.done()
):
worker_health_check_task = asyncio.create_task(health_check())
if initialized:
return
# Directly remove completed tasks from the tasks set
tasks.difference_update({t for t in tasks if t.done()})
async with initialization_lock:
if initialized:
return
# Calculate the number of active tasks
active_tasks_count = len(tasks)
logger.info("limit_async: Initializing worker system")
# If active tasks count is less than max_size, create new workers
workers_needed = max_size - active_tasks_count
if workers_needed > 0:
for _ in range(workers_needed):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
finally:
# Ensure the lock is released
if lock_acquired:
lock.release()
except Exception as e:
logger.error(f"limit_async: Error in ensure_workers: {str(e)}")
# Even if an exception occurs, try to create at least one worker
if not any(not t.done() for t in tasks):
# Create initial worker tasks
for _ in range(max_size):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
# Start health check
worker_health_check_task = asyncio.create_task(health_check())
initialized = True
logger.info("limit_async: Worker system initialized")
async def shutdown():
"""Gracefully shut down all workers and the queue"""
logger.info("limit_async: Shutting down priority queue workers")
@@ -480,7 +467,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
QueueFullError: If the queue is full and waiting times out
Any exception raised by the decorated function
"""
# Ensure workers are started
# Ensure worker system is initialized
await ensure_workers()
# Create a future for the result
@@ -488,7 +475,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
active_futures.add(future)
nonlocal counter
async with lock:
async with initialization_lock:
current_count = counter
counter += 1