Enhance the function's robustness
This commit is contained in:
@@ -1024,7 +1024,8 @@ class LightRAG:
|
||||
}
|
||||
)
|
||||
|
||||
# Release semphore before entering to merge stage
|
||||
# Semphore was released here
|
||||
|
||||
if file_extraction_stage_ok:
|
||||
try:
|
||||
# Get chunk_results from entity_relation_task
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import weakref
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
@@ -267,119 +268,250 @@ def compute_mdhash_id(content: str, prefix: str = "") -> str:
|
||||
return prefix + md5(content.encode()).hexdigest()
|
||||
|
||||
|
||||
def limit_async_func_call(max_size: int):
|
||||
"""Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
|
||||
# Custom exception class
|
||||
class QueueFullError(Exception):
|
||||
"""Raised when the queue is full and the wait times out"""
|
||||
pass
|
||||
|
||||
def final_decro(func):
|
||||
sem = asyncio.Semaphore(max_size)
|
||||
|
||||
@wraps(func)
|
||||
async def wait_func(*args, **kwargs):
|
||||
async with sem:
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wait_func
|
||||
|
||||
return final_decro
|
||||
|
||||
|
||||
def priority_limit_async_func_call(max_size: int):
|
||||
def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
|
||||
"""
|
||||
Add restriction of maximum concurrent async calls using priority queue.
|
||||
Lower priority value means higher priority.
|
||||
Enhanced priority-limited asynchronous function call decorator
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of concurrent calls
|
||||
|
||||
max_queue_size: Maximum queue capacity to prevent memory overflow
|
||||
Returns:
|
||||
A decorator that wraps an async function with priority-based concurrency control
|
||||
Decorator function
|
||||
"""
|
||||
|
||||
def final_decro(func):
|
||||
# Create shared worker pool resources
|
||||
queue = asyncio.PriorityQueue()
|
||||
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
|
||||
tasks = set()
|
||||
lock = asyncio.Lock()
|
||||
counter = 0
|
||||
shutdown_event = asyncio.Event()
|
||||
worker_health_check_task = None
|
||||
|
||||
# Worker function that processes tasks from the queue
|
||||
# Track active future objects for cleanup
|
||||
active_futures = weakref.WeakSet()
|
||||
|
||||
# Worker function to process tasks in the queue
|
||||
async def worker():
|
||||
"""Worker that processes tasks from the priority queue"""
|
||||
while True:
|
||||
# Get task from queue (priority, count, future, args, kwargs)
|
||||
_, _, future, args, kwargs = await queue.get()
|
||||
"""Worker that processes tasks in the priority queue"""
|
||||
try:
|
||||
# Execute the function
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
# Use timeout to get tasks, allowing periodic checking of shutdown signal
|
||||
try:
|
||||
priority, count, future, args, kwargs = await asyncio.wait_for(
|
||||
queue.get(), timeout=1.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout is just to check shutdown signal, continue to next iteration
|
||||
continue
|
||||
|
||||
# If future is cancelled, skip execution
|
||||
if future.cancelled():
|
||||
queue.task_done()
|
||||
continue
|
||||
|
||||
try:
|
||||
# Execute function
|
||||
result = await func(*args, **kwargs)
|
||||
# Set result to future if not already done
|
||||
# If future is not done, set the result
|
||||
if not future.done():
|
||||
future.set_result(result)
|
||||
except asyncio.CancelledError:
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
logger.debug("limit_async: Task cancelled during execution")
|
||||
except Exception as e:
|
||||
# Set exception to future if not already done
|
||||
logger.error(f"limit_async: Error in decorated function: {str(e)}")
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
finally:
|
||||
# Mark task as done
|
||||
queue.task_done()
|
||||
except Exception as e:
|
||||
# Catch all exceptions in worker loop to prevent worker termination
|
||||
logger.error(f"limit_async: Critical error in worker: {str(e)}")
|
||||
await asyncio.sleep(0.1) # Prevent high CPU usage
|
||||
finally:
|
||||
logger.warning("limit_async: Worker exiting")
|
||||
|
||||
async def health_check():
|
||||
"""Periodically check worker health status and recover"""
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
await asyncio.sleep(5) # Check every 5 seconds
|
||||
|
||||
async with lock:
|
||||
# Directly remove completed tasks from the tasks set
|
||||
tasks.difference_update({t for t in tasks if t.done()})
|
||||
|
||||
# Create new workers if active tasks less than max_size for better performance
|
||||
active_tasks_count = len(tasks)
|
||||
workers_needed = max_size - active_tasks_count
|
||||
if workers_needed > 0:
|
||||
logger.info(f"limit_async: Creating {workers_needed} new workers")
|
||||
for _ in range(workers_needed):
|
||||
task = asyncio.create_task(worker())
|
||||
tasks.add(task)
|
||||
task.add_done_callback(tasks.discard)
|
||||
except Exception as e:
|
||||
logger.error(f"limit_async: Error in health check: {str(e)}")
|
||||
finally:
|
||||
logger.warning("limit_async: Health check task exiting")
|
||||
|
||||
# Ensure worker tasks are started
|
||||
async def ensure_workers():
|
||||
"""Ensure worker tasks are started"""
|
||||
nonlocal tasks
|
||||
async with lock:
|
||||
if not tasks:
|
||||
# Start worker tasks
|
||||
for _ in range(max_size):
|
||||
"""Ensure worker tasks and health check are started"""
|
||||
nonlocal tasks, worker_health_check_task
|
||||
|
||||
# Use timeout lock to prevent deadlock
|
||||
try:
|
||||
lock_acquired = False
|
||||
try:
|
||||
# Try to acquire the lock, wait up to 5 seconds
|
||||
lock_acquired = await asyncio.wait_for(lock.acquire(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("limit_async: Timeout acquiring lock in ensure_workers")
|
||||
# Even if acquiring the lock times out, continue trying to create workers
|
||||
|
||||
try:
|
||||
# Start the health check task (if not already started)
|
||||
if worker_health_check_task is None or worker_health_check_task.done():
|
||||
worker_health_check_task = asyncio.create_task(health_check())
|
||||
|
||||
# Directly remove completed tasks from the tasks set
|
||||
tasks.difference_update({t for t in tasks if t.done()})
|
||||
|
||||
# Calculate the number of active tasks
|
||||
active_tasks_count = len(tasks)
|
||||
|
||||
# If active tasks count is less than max_size, create new workers
|
||||
workers_needed = max_size - active_tasks_count
|
||||
if workers_needed > 0:
|
||||
for _ in range(workers_needed):
|
||||
task = asyncio.create_task(worker())
|
||||
tasks.add(task)
|
||||
task.add_done_callback(tasks.discard)
|
||||
finally:
|
||||
# Ensure the lock is released
|
||||
if lock_acquired:
|
||||
lock.release()
|
||||
except Exception as e:
|
||||
logger.error(f"limit_async: Error in ensure_workers: {str(e)}")
|
||||
# Even if an exception occurs, try to create at least one worker
|
||||
if not any(not t.done() for t in tasks):
|
||||
task = asyncio.create_task(worker())
|
||||
tasks.add(task)
|
||||
# Remove task from set when done
|
||||
task.add_done_callback(tasks.discard)
|
||||
|
||||
async def shutdown():
|
||||
"""Gracefully shut down all workers and the queue"""
|
||||
logger.info("limit_async: Shutting down priority queue workers")
|
||||
|
||||
# Set the shutdown event
|
||||
shutdown_event.set()
|
||||
|
||||
# Cancel all active futures
|
||||
for future in list(active_futures):
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
|
||||
# Wait for the queue to empty
|
||||
try:
|
||||
await asyncio.wait_for(queue.join(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("limit_async: Timeout waiting for queue to empty during shutdown")
|
||||
|
||||
# Cancel all worker tasks
|
||||
for task in list(tasks):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to complete
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Cancel the health check task
|
||||
if worker_health_check_task and not worker_health_check_task.done():
|
||||
worker_health_check_task.cancel()
|
||||
try:
|
||||
await worker_health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("limit_async: Priority queue workers shutdown complete")
|
||||
|
||||
@wraps(func)
|
||||
async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
|
||||
async def wait_func(*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs):
|
||||
"""
|
||||
Execute function with priority-based concurrency control
|
||||
|
||||
Execute the function with priority-based concurrency control
|
||||
Args:
|
||||
*args: Positional arguments to pass to the function
|
||||
priority: Priority of the call (lower value means higher priority)
|
||||
timeout: Maximum time in seconds to wait for the function to complete
|
||||
**kwargs: Keyword arguments to pass to the function
|
||||
|
||||
*args: Positional arguments passed to the function
|
||||
_priority: Call priority (lower values have higher priority)
|
||||
_timeout: Maximum time to wait for function completion (in seconds)
|
||||
_queue_timeout: Maximum time to wait for entering the queue (in seconds)
|
||||
**kwargs: Keyword arguments passed to the function
|
||||
Returns:
|
||||
Result of the function call
|
||||
|
||||
The result of the function call
|
||||
Raises:
|
||||
TimeoutError: If the function call times out
|
||||
Any exception raised by the function
|
||||
QueueFullError: If the queue is full and waiting times out
|
||||
Any exception raised by the decorated function
|
||||
"""
|
||||
# Ensure workers are started
|
||||
await ensure_workers()
|
||||
|
||||
# Create future for result
|
||||
# Create a future for the result
|
||||
future = asyncio.Future()
|
||||
active_futures.add(future)
|
||||
|
||||
nonlocal counter
|
||||
async with lock:
|
||||
current_count = counter
|
||||
counter += 1
|
||||
|
||||
# Put task in queue with priority and monotonic counter
|
||||
# Try to put the task into the queue, supporting timeout
|
||||
try:
|
||||
if _queue_timeout is not None:
|
||||
# Use timeout to wait for queue space
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
queue.put((_priority, current_count, future, args, kwargs)),
|
||||
timeout=_queue_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise QueueFullError(f"Queue full, timeout after {_queue_timeout} seconds")
|
||||
else:
|
||||
# No timeout, may wait indefinitely
|
||||
await queue.put((_priority, current_count, future, args, kwargs))
|
||||
except Exception as e:
|
||||
# Clean up the future
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
active_futures.discard(future)
|
||||
raise
|
||||
|
||||
# Wait for result with optional timeout
|
||||
try:
|
||||
# Wait for the result, optional timeout
|
||||
if _timeout is not None:
|
||||
try:
|
||||
return await asyncio.wait_for(future, _timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# Cancel future if possible
|
||||
# Cancel the future
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
raise TimeoutError(f"Task timed out after {_timeout} seconds")
|
||||
raise TimeoutError(f"limit_async: Task timed out after {_timeout} seconds")
|
||||
else:
|
||||
# Wait for result without timeout
|
||||
# Wait for the result without timeout
|
||||
return await future
|
||||
finally:
|
||||
# Clean up the future reference
|
||||
active_futures.discard(future)
|
||||
|
||||
# Add the shutdown method to the decorated function
|
||||
wait_func.shutdown = shutdown
|
||||
|
||||
return wait_func
|
||||
|
||||
|
Reference in New Issue
Block a user