Fix race condition for health_check and ensure_workers

This commit is contained in:
yangdx
2025-04-29 00:08:52 +08:00
parent 1fc26127d5
commit 1afcbcbfb5

View File

@@ -289,9 +289,10 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
def final_decro(func): def final_decro(func):
queue = asyncio.PriorityQueue(maxsize=max_queue_size) queue = asyncio.PriorityQueue(maxsize=max_queue_size)
tasks = set() tasks = set()
lock = asyncio.Lock() initialization_lock = asyncio.Lock()
counter = 0 counter = 0
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
initialized = False # Global initialization flag
worker_health_check_task = None worker_health_check_task = None
# Track active future objects for cleanup # Track active future objects for cleanup
@@ -352,75 +353,61 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
while not shutdown_event.is_set(): while not shutdown_event.is_set():
await asyncio.sleep(5) # Check every 5 seconds await asyncio.sleep(5) # Check every 5 seconds
async with lock: # No longer acquire lock, directly operate on task set
# Directly remove completed tasks from the tasks set # Use a copy of the task set to avoid concurrent modification
tasks.difference_update({t for t in tasks if t.done()}) current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks)
# Create new workers if active tasks less than max_size for better performance # Calculate active tasks count
active_tasks_count = len(tasks) active_tasks_count = len(tasks)
workers_needed = max_size - active_tasks_count workers_needed = max_size - active_tasks_count
if workers_needed > 0: if workers_needed > 0:
logger.info( logger.info(
f"limit_async: Creating {workers_needed} new workers" f"limit_async: Creating {workers_needed} new workers"
) )
new_tasks = set()
for _ in range(workers_needed): for _ in range(workers_needed):
task = asyncio.create_task(worker()) task = asyncio.create_task(worker())
tasks.add(task) new_tasks.add(task)
task.add_done_callback(tasks.discard) task.add_done_callback(tasks.discard)
# Update task set in one operation
tasks.update(new_tasks)
except Exception as e: except Exception as e:
logger.error(f"limit_async: Error in health check: {str(e)}") logger.error(f"limit_async: Error in health check: {str(e)}")
finally: finally:
logger.warning("limit_async: Health check task exiting") logger.warning("limit_async: Health check task exiting")
# Ensure worker tasks are started
async def ensure_workers(): async def ensure_workers():
"""Ensure worker tasks and health check are started""" """Ensure worker threads and health check system are available
nonlocal tasks, worker_health_check_task
# Use timeout lock to prevent deadlock This function checks if the worker system is already initialized.
try: If not, it performs a one-time initialization of all worker threads
lock_acquired = False and starts the health check system.
try: """
# Try to acquire the lock, wait up to 5 seconds nonlocal initialized, worker_health_check_task, tasks
lock_acquired = await asyncio.wait_for(lock.acquire(), timeout=5.0)
except asyncio.TimeoutError:
logger.error(
"limit_async: Timeout acquiring lock in ensure_workers"
)
# Even if acquiring the lock times out, continue trying to create workers
try: if initialized:
# Start the health check task (if not already started) return
if (
worker_health_check_task is None async with initialization_lock:
or worker_health_check_task.done() if initialized:
): return
logger.info("limit_async: Initializing worker system")
# Create initial worker tasks
for _ in range(max_size):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
# Start health check
worker_health_check_task = asyncio.create_task(health_check()) worker_health_check_task = asyncio.create_task(health_check())
# Directly remove completed tasks from the tasks set initialized = True
tasks.difference_update({t for t in tasks if t.done()}) logger.info("limit_async: Worker system initialized")
# Calculate the number of active tasks
active_tasks_count = len(tasks)
# If active tasks count is less than max_size, create new workers
workers_needed = max_size - active_tasks_count
if workers_needed > 0:
for _ in range(workers_needed):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
finally:
# Ensure the lock is released
if lock_acquired:
lock.release()
except Exception as e:
logger.error(f"limit_async: Error in ensure_workers: {str(e)}")
# Even if an exception occurs, try to create at least one worker
if not any(not t.done() for t in tasks):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
async def shutdown(): async def shutdown():
"""Gracefully shut down all workers and the queue""" """Gracefully shut down all workers and the queue"""
@@ -480,7 +467,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
QueueFullError: If the queue is full and waiting times out QueueFullError: If the queue is full and waiting times out
Any exception raised by the decorated function Any exception raised by the decorated function
""" """
# Ensure workers are started # Ensure worker system is initialized
await ensure_workers() await ensure_workers()
# Create a future for the result # Create a future for the result
@@ -488,7 +475,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
active_futures.add(future) active_futures.add(future)
nonlocal counter nonlocal counter
async with lock: async with initialization_lock:
current_count = counter current_count = counter
counter += 1 counter += 1