Enhance the function's robustness

This commit is contained in:
yangdx
2025-04-28 22:52:31 +08:00
parent 90a07b0420
commit 0ecae90002
2 changed files with 216 additions and 83 deletions

View File

@@ -1024,7 +1024,8 @@ class LightRAG:
} }
) )
# Release semphore before entering to merge stage # Semphore was released here
if file_extraction_stage_ok: if file_extraction_stage_ok:
try: try:
# Get chunk_results from entity_relation_task # Get chunk_results from entity_relation_task

View File

@@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import weakref
import asyncio import asyncio
import html import html
@@ -267,119 +268,250 @@ def compute_mdhash_id(content: str, prefix: str = "") -> str:
return prefix + md5(content.encode()).hexdigest() return prefix + md5(content.encode()).hexdigest()
# Custom exception class
class QueueFullError(Exception):
    """Raised when the task queue is full and the enqueue wait times out."""

    pass


def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
    """
    Enhanced priority-limited asynchronous function call decorator.

    Lower priority values run first; ties are broken FIFO by a monotonic
    counter. Concurrency is bounded by a pool of worker tasks consuming a
    shared priority queue, with a periodic health check that replaces dead
    workers.

    Args:
        max_size: Maximum number of concurrent calls
        max_queue_size: Maximum queue capacity to prevent memory overflow

    Returns:
        Decorator function
    """

    def final_decro(func):
        queue = asyncio.PriorityQueue(maxsize=max_queue_size)
        tasks = set()  # live worker tasks
        lock = asyncio.Lock()  # guards tasks / counter / health-check state
        counter = 0  # monotonic tie-breaker so equal priorities stay FIFO
        shutdown_event = asyncio.Event()
        worker_health_check_task = None
        # Track active caller futures so shutdown can cancel them; a WeakSet
        # avoids keeping futures alive after callers are done with them.
        active_futures = weakref.WeakSet()

        async def worker():
            """Worker that processes tasks in the priority queue."""
            try:
                while not shutdown_event.is_set():
                    # Use a timeout on get() so the shutdown signal is
                    # re-checked at least once per second.
                    try:
                        priority, count, future, args, kwargs = await asyncio.wait_for(
                            queue.get(), timeout=1.0
                        )
                    except asyncio.TimeoutError:
                        continue

                    # If the caller already gave up, skip execution
                    if future.cancelled():
                        queue.task_done()
                        continue

                    try:
                        result = await func(*args, **kwargs)
                        # Guard set_result: the caller may have timed out
                        # and cancelled the future while func was running.
                        if not future.done():
                            future.set_result(result)
                    except asyncio.CancelledError:
                        if not future.done():
                            future.cancel()
                        logger.debug("limit_async: Task cancelled during execution")
                    except Exception as e:
                        logger.error(f"limit_async: Error in decorated function: {str(e)}")
                        if not future.done():
                            future.set_exception(e)
                    finally:
                        queue.task_done()
            except Exception as e:
                # Catch-all so a single unexpected failure cannot silently
                # kill the worker loop; health_check will replace us if we die.
                logger.error(f"limit_async: Critical error in worker: {str(e)}")
                await asyncio.sleep(0.1)  # Prevent high CPU usage
            finally:
                logger.warning("limit_async: Worker exiting")

        async def health_check():
            """Periodically check worker health status and recover."""
            try:
                while not shutdown_event.is_set():
                    await asyncio.sleep(5)  # Check every 5 seconds
                    async with lock:
                        # Drop finished workers, then top the pool back up
                        tasks.difference_update({t for t in tasks if t.done()})
                        workers_needed = max_size - len(tasks)
                        if workers_needed > 0:
                            logger.info(
                                f"limit_async: Creating {workers_needed} new workers"
                            )
                            for _ in range(workers_needed):
                                task = asyncio.create_task(worker())
                                tasks.add(task)
                                task.add_done_callback(tasks.discard)
            except Exception as e:
                logger.error(f"limit_async: Error in health check: {str(e)}")
            finally:
                logger.warning("limit_async: Health check task exiting")

        async def ensure_workers():
            """Ensure worker tasks and the health check are started.

            BUGFIX: the previous version acquired the lock via
            asyncio.wait_for(lock.acquire(), timeout=5.0) and, on timeout,
            mutated shared state WITHOUT holding the lock (racing with
            health_check). Worse, wait_for can cancel acquire() at the very
            moment the lock is granted, leaking a permanently-held lock.
            Every holder of this lock releases it promptly, so a plain
            ``async with`` cannot deadlock and needs no timeout.
            """
            nonlocal worker_health_check_task
            try:
                async with lock:
                    # Start the health check task (if not already running)
                    if worker_health_check_task is None or worker_health_check_task.done():
                        worker_health_check_task = asyncio.create_task(health_check())

                    # Drop finished workers, then create as many as needed
                    # to bring the pool back to max_size.
                    tasks.difference_update({t for t in tasks if t.done()})
                    workers_needed = max_size - len(tasks)
                    for _ in range(workers_needed):
                        task = asyncio.create_task(worker())
                        tasks.add(task)
                        task.add_done_callback(tasks.discard)
            except Exception as e:
                logger.error(f"limit_async: Error in ensure_workers: {str(e)}")
                # Best effort: keep at least one worker alive so queued
                # futures are not stranded forever.
                if not any(not t.done() for t in tasks):
                    task = asyncio.create_task(worker())
                    tasks.add(task)
                    task.add_done_callback(tasks.discard)

        async def shutdown():
            """Gracefully shut down all workers and the queue."""
            logger.info("limit_async: Shutting down priority queue workers")
            shutdown_event.set()

            # Cancel all still-pending caller futures
            for future in list(active_futures):
                if not future.done():
                    future.cancel()

            # Give in-flight work a bounded chance to drain
            try:
                await asyncio.wait_for(queue.join(), timeout=5.0)
            except asyncio.TimeoutError:
                logger.warning(
                    "limit_async: Timeout waiting for queue to empty during shutdown"
                )

            # Cancel all worker tasks and wait for them to finish
            for task in list(tasks):
                if not task.done():
                    task.cancel()
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)

            # Cancel the health check task
            if worker_health_check_task and not worker_health_check_task.done():
                worker_health_check_task.cancel()
                try:
                    await worker_health_check_task
                except asyncio.CancelledError:
                    pass

            logger.info("limit_async: Priority queue workers shutdown complete")

        @wraps(func)
        async def wait_func(
            *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
        ):
            """
            Execute the function with priority-based concurrency control.

            Args:
                *args: Positional arguments passed to the function
                _priority: Call priority (lower values have higher priority)
                _timeout: Maximum time to wait for function completion (in seconds)
                _queue_timeout: Maximum time to wait for entering the queue (in seconds)
                **kwargs: Keyword arguments passed to the function

            Returns:
                The result of the function call

            Raises:
                TimeoutError: If the function call times out
                QueueFullError: If the queue is full and waiting times out
                Any exception raised by the decorated function
            """
            nonlocal counter

            # Ensure workers are started
            await ensure_workers()

            # Create a future for the result
            future = asyncio.Future()
            active_futures.add(future)

            async with lock:
                current_count = counter
                counter += 1

            # Try to put the task into the queue, optionally bounded by
            # _queue_timeout (None means wait indefinitely for queue space).
            item = (_priority, current_count, future, args, kwargs)
            try:
                if _queue_timeout is not None:
                    try:
                        await asyncio.wait_for(queue.put(item), timeout=_queue_timeout)
                    except asyncio.TimeoutError:
                        raise QueueFullError(
                            f"Queue full, timeout after {_queue_timeout} seconds"
                        ) from None
                else:
                    await queue.put(item)
            except BaseException:
                # BUGFIX: do not set_exception() on a future nobody will
                # await — that triggers "exception was never retrieved"
                # warnings. Just drop the future and propagate to the caller.
                active_futures.discard(future)
                raise

            try:
                # Wait for the result, with optional timeout
                if _timeout is not None:
                    try:
                        return await asyncio.wait_for(future, _timeout)
                    except asyncio.TimeoutError:
                        # Cancel the future so a worker skips/ignores it
                        if not future.done():
                            future.cancel()
                        raise TimeoutError(
                            f"limit_async: Task timed out after {_timeout} seconds"
                        )
                else:
                    return await future
            finally:
                # Clean up the future reference
                active_futures.discard(future)

        # Expose graceful shutdown on the decorated function
        wait_func.shutdown = shutdown
        return wait_func

    return final_decro