From 1afcbcbfb572c8ef6fcf9a5f8602e774507fb27f Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 29 Apr 2025 00:08:52 +0800 Subject: [PATCH] Fix race condition for health_check and ensure_workers --- lightrag/utils.py | 101 ++++++++++++++++++++-------------------------- 1 file changed, 44 insertions(+), 57 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index da006647..574e3ecd 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -289,9 +289,10 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): def final_decro(func): queue = asyncio.PriorityQueue(maxsize=max_queue_size) tasks = set() - lock = asyncio.Lock() + initialization_lock = asyncio.Lock() counter = 0 shutdown_event = asyncio.Event() + initialized = False # Global initialization flag worker_health_check_task = None # Track active future objects for cleanup @@ -352,76 +353,62 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): while not shutdown_event.is_set(): await asyncio.sleep(5) # Check every 5 seconds - async with lock: - # Directly remove completed tasks from the tasks set - tasks.difference_update({t for t in tasks if t.done()}) + # No longer acquire lock, directly operate on task set + # Use a copy of the task set to avoid concurrent modification + current_tasks = set(tasks) + done_tasks = {t for t in current_tasks if t.done()} + tasks.difference_update(done_tasks) - # Create new workers if active tasks less than max_size for better performance - active_tasks_count = len(tasks) - workers_needed = max_size - active_tasks_count - if workers_needed > 0: - logger.info( - f"limit_async: Creating {workers_needed} new workers" - ) - for _ in range(workers_needed): - task = asyncio.create_task(worker()) - tasks.add(task) - task.add_done_callback(tasks.discard) + # Calculate active tasks count + active_tasks_count = len(tasks) + workers_needed = max_size - active_tasks_count + + if workers_needed > 0: + logger.info( + f"limit_async: Creating {workers_needed} new workers" + ) + new_tasks = set() + for _ in range(workers_needed): + task = asyncio.create_task(worker()) + new_tasks.add(task) + task.add_done_callback(tasks.discard) + # Update task set in one operation + tasks.update(new_tasks) except Exception as e: logger.error(f"limit_async: Error in health check: {str(e)}") finally: logger.warning("limit_async: Health check task exiting") - # Ensure worker tasks are started async def ensure_workers(): - """Ensure worker tasks and health check are started""" - nonlocal tasks, worker_health_check_task + """Ensure worker threads and health check system are available - # Use timeout lock to prevent deadlock - try: - lock_acquired = False - try: - # Try to acquire the lock, wait up to 5 seconds - lock_acquired = await asyncio.wait_for(lock.acquire(), timeout=5.0) - except asyncio.TimeoutError: - logger.error( - "limit_async: Timeout acquiring lock in ensure_workers" - ) - # Even if acquiring the lock times out, continue trying to create workers + This function checks if the worker system is already initialized. + If not, it performs a one-time initialization of all worker threads + and starts the health check system. + """ + nonlocal initialized, worker_health_check_task, tasks - try: - # Start the health check task (if not already started) - if ( - worker_health_check_task is None - or worker_health_check_task.done() - ): - worker_health_check_task = asyncio.create_task(health_check()) + if initialized: + return - # Directly remove completed tasks from the tasks set - tasks.difference_update({t for t in tasks if t.done()}) + async with initialization_lock: + if initialized: + return - # Calculate the number of active tasks - active_tasks_count = len(tasks) + logger.info("limit_async: Initializing worker system") - # If active tasks count is less than max_size, create new workers - workers_needed = max_size - active_tasks_count - if workers_needed > 0: - for _ in range(workers_needed): - task = asyncio.create_task(worker()) - tasks.add(task) - task.add_done_callback(tasks.discard) - finally: - # Ensure the lock is released - if lock_acquired: - lock.release() - except Exception as e: - logger.error(f"limit_async: Error in ensure_workers: {str(e)}") - # Even if an exception occurs, try to create at least one worker - if not any(not t.done() for t in tasks): + # Create initial worker tasks + for _ in range(max_size): task = asyncio.create_task(worker()) tasks.add(task) task.add_done_callback(tasks.discard) + # Start health check + worker_health_check_task = asyncio.create_task(health_check()) + + initialized = True + logger.info("limit_async: Worker system initialized") + async def shutdown(): """Gracefully shut down all workers and the queue""" logger.info("limit_async: Shutting down priority queue workers") @@ -480,7 +467,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): QueueFullError: If the queue is full and waiting times out Any exception raised by the decorated function """ - # Ensure workers are started + # Ensure worker system is initialized await ensure_workers() # Create a future for the result @@ -488,7 +475,7 @@ def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): active_futures.add(future) nonlocal counter - async with lock: + async with initialization_lock: current_count = counter counter += 1