From 0ecae90002d52592ca4b0640062e54d8dd7d8f56 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 28 Apr 2025 22:52:31 +0800 Subject: [PATCH] Enhance the function's robustness --- lightrag/lightrag.py | 3 +- lightrag/utils.py | 296 +++++++++++++++++++++++++++++++------------ 2 files changed, 216 insertions(+), 83 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9b22a9b8..7a79da31 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1024,7 +1024,8 @@ class LightRAG: } ) - # Release semphore before entering to merge stage + # Semaphore was released here + if file_extraction_stage_ok: try: # Get chunk_results from entity_relation_task diff --git a/lightrag/utils.py b/lightrag/utils.py index 054cd777..77314053 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1,4 +1,5 @@ from __future__ import annotations +import weakref import asyncio import html @@ -267,122 +268,253 @@ def compute_mdhash_id(content: str, prefix: str = "") -> str: return prefix + md5(content.encode()).hexdigest() -def limit_async_func_call(max_size: int): - """Add restriction of maximum concurrent async calls using asyncio.Semaphore""" +# Custom exception class +class QueueFullError(Exception): + """Raised when the queue is full and the wait times out""" + pass - def final_decro(func): - sem = asyncio.Semaphore(max_size) - - @wraps(func) - async def wait_func(*args, **kwargs): - async with sem: - result = await func(*args, **kwargs) - return result - - return wait_func - - return final_decro - - -def priority_limit_async_func_call(max_size: int): +def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000): """ - Add restriction of maximum concurrent async calls using priority queue. - Lower priority value means higher priority. 
- + Enhanced priority-limited asynchronous function call decorator + Args: max_size: Maximum number of concurrent calls - + max_queue_size: Maximum queue capacity to prevent memory overflow Returns: - A decorator that wraps an async function with priority-based concurrency control + Decorator function """ - def final_decro(func): - # Create shared worker pool resources - queue = asyncio.PriorityQueue() + queue = asyncio.PriorityQueue(maxsize=max_queue_size) tasks = set() lock = asyncio.Lock() counter = 0 - - # Worker function that processes tasks from the queue + shutdown_event = asyncio.Event() + worker_health_check_task = None + + # Track active future objects for cleanup + active_futures = weakref.WeakSet() + + # Worker function to process tasks in the queue async def worker(): - """Worker that processes tasks from the priority queue""" - while True: - # Get task from queue (priority, count, future, args, kwargs) - _, _, future, args, kwargs = await queue.get() - try: - # Execute the function - result = await func(*args, **kwargs) - # Set result to future if not already done - if not future.done(): - future.set_result(result) - except Exception as e: - # Set exception to future if not already done - if not future.done(): - future.set_exception(e) - finally: - # Mark task as done - queue.task_done() - + """Worker that processes tasks in the priority queue""" + try: + while not shutdown_event.is_set(): + try: + # Use timeout to get tasks, allowing periodic checking of shutdown signal + try: + priority, count, future, args, kwargs = await asyncio.wait_for( + queue.get(), timeout=1.0 + ) + except asyncio.TimeoutError: + # Timeout is just to check shutdown signal, continue to next iteration + continue + + # If future is cancelled, skip execution + if future.cancelled(): + queue.task_done() + continue + + try: + # Execute function + result = await func(*args, **kwargs) + # If future is not done, set the result + if not future.done(): + future.set_result(result) + 
except asyncio.CancelledError: + if not future.done(): + future.cancel() + logger.debug("limit_async: Task cancelled during execution") + except Exception as e: + logger.error(f"limit_async: Error in decorated function: {str(e)}") + if not future.done(): + future.set_exception(e) + finally: + queue.task_done() + except Exception as e: + # Catch all exceptions in worker loop to prevent worker termination + logger.error(f"limit_async: Critical error in worker: {str(e)}") + await asyncio.sleep(0.1) # Prevent high CPU usage + finally: + logger.warning("limit_async: Worker exiting") + + async def health_check(): + """Periodically check worker health status and recover""" + try: + while not shutdown_event.is_set(): + await asyncio.sleep(5) # Check every 5 seconds + + async with lock: + # Directly remove completed tasks from the tasks set + tasks.difference_update({t for t in tasks if t.done()}) + + # Create new workers if active tasks less than max_size for better performance + active_tasks_count = len(tasks) + workers_needed = max_size - active_tasks_count + if workers_needed > 0: + logger.info(f"limit_async: Creating {workers_needed} new workers") + for _ in range(workers_needed): + task = asyncio.create_task(worker()) + tasks.add(task) + task.add_done_callback(tasks.discard) + except Exception as e: + logger.error(f"limit_async: Error in health check: {str(e)}") + finally: + logger.warning("limit_async: Health check task exiting") + # Ensure worker tasks are started async def ensure_workers(): - """Ensure worker tasks are started""" - nonlocal tasks - async with lock: - if not tasks: - # Start worker tasks - for _ in range(max_size): - task = asyncio.create_task(worker()) - tasks.add(task) - # Remove task from set when done - task.add_done_callback(tasks.discard) + """Ensure worker tasks and health check are started""" + nonlocal tasks, worker_health_check_task + + # Use timeout lock to prevent deadlock + try: + lock_acquired = False + try: + # Try to acquire the 
lock, wait up to 5 seconds + lock_acquired = await asyncio.wait_for(lock.acquire(), timeout=5.0) + except asyncio.TimeoutError: + logger.error("limit_async: Timeout acquiring lock in ensure_workers") + # Even if acquiring the lock times out, continue trying to create workers + + try: + # Start the health check task (if not already started) + if worker_health_check_task is None or worker_health_check_task.done(): + worker_health_check_task = asyncio.create_task(health_check()) + + # Directly remove completed tasks from the tasks set + tasks.difference_update({t for t in tasks if t.done()}) + + # Calculate the number of active tasks + active_tasks_count = len(tasks) + + # If active tasks count is less than max_size, create new workers + workers_needed = max_size - active_tasks_count + if workers_needed > 0: + for _ in range(workers_needed): + task = asyncio.create_task(worker()) + tasks.add(task) + task.add_done_callback(tasks.discard) + finally: + # Ensure the lock is released + if lock_acquired: + lock.release() + except Exception as e: + logger.error(f"limit_async: Error in ensure_workers: {str(e)}") + # Even if an exception occurs, try to create at least one worker + if not any(not t.done() for t in tasks): + task = asyncio.create_task(worker()) + tasks.add(task) + task.add_done_callback(tasks.discard) + + async def shutdown(): + """Gracefully shut down all workers and the queue""" + logger.info("limit_async: Shutting down priority queue workers") + + # Set the shutdown event + shutdown_event.set() + + # Cancel all active futures + for future in list(active_futures): + if not future.done(): + future.cancel() + + # Wait for the queue to empty + try: + await asyncio.wait_for(queue.join(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning("limit_async: Timeout waiting for queue to empty during shutdown") + + # Cancel all worker tasks + for task in list(tasks): + if not task.done(): + task.cancel() + + # Wait for all tasks to complete + if tasks: + await 
asyncio.gather(*tasks, return_exceptions=True) + + # Cancel the health check task + if worker_health_check_task and not worker_health_check_task.done(): + worker_health_check_task.cancel() + try: + await worker_health_check_task + except asyncio.CancelledError: + pass + + logger.info("limit_async: Priority queue workers shutdown complete") @wraps(func) - async def wait_func(*args, _priority=10, _timeout=None, **kwargs): + async def wait_func(*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs): """ - Execute function with priority-based concurrency control - + Execute the function with priority-based concurrency control Args: - *args: Positional arguments to pass to the function - priority: Priority of the call (lower value means higher priority) - timeout: Maximum time in seconds to wait for the function to complete - **kwargs: Keyword arguments to pass to the function - + *args: Positional arguments passed to the function + _priority: Call priority (lower values have higher priority) + _timeout: Maximum time to wait for function completion (in seconds) + _queue_timeout: Maximum time to wait for entering the queue (in seconds) + **kwargs: Keyword arguments passed to the function Returns: - Result of the function call - + The result of the function call Raises: TimeoutError: If the function call times out - Any exception raised by the function + QueueFullError: If the queue is full and waiting times out + Any exception raised by the decorated function """ # Ensure workers are started await ensure_workers() - # Create future for result + # Create a future for the result future = asyncio.Future() + active_futures.add(future) nonlocal counter async with lock: current_count = counter counter += 1 - # Put task in queue with priority and monotonic counter - await queue.put((_priority, current_count, future, args, kwargs)) + # Try to put the task into the queue, supporting timeout + try: + if _queue_timeout is not None: + # Use timeout to wait for queue 
space + try: + await asyncio.wait_for( + queue.put((_priority, current_count, future, args, kwargs)), + timeout=_queue_timeout + ) + except asyncio.TimeoutError: + raise QueueFullError(f"Queue full, timeout after {_queue_timeout} seconds") + else: + # No timeout, may wait indefinitely + await queue.put((_priority, current_count, future, args, kwargs)) + except Exception as e: + # Clean up the future + if not future.done(): + future.set_exception(e) + active_futures.discard(future) + raise - # Wait for result with optional timeout - if _timeout is not None: - try: - return await asyncio.wait_for(future, _timeout) - except asyncio.TimeoutError: - # Cancel future if possible - if not future.done(): - future.cancel() - raise TimeoutError(f"Task timed out after {_timeout} seconds") - else: - # Wait for result without timeout - return await future + try: + # Wait for the result, optional timeout + if _timeout is not None: + try: + return await asyncio.wait_for(future, _timeout) + except asyncio.TimeoutError: + # Cancel the future + if not future.done(): + future.cancel() + raise TimeoutError(f"limit_async: Task timed out after {_timeout} seconds") + else: + # Wait for the result without timeout + return await future + finally: + # Clean up the future reference + active_futures.discard(future) + + # Add the shutdown method to the decorated function + wait_func.shutdown = shutdown return wait_func - + return final_decro