Add priority control for limited async decorator
This commit is contained in:
@@ -8,6 +8,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from hashlib import md5
|
||||
@@ -284,6 +285,105 @@ def limit_async_func_call(max_size: int):
|
||||
return final_decro
|
||||
|
||||
|
||||
def priority_limit_async_func_call(max_size: int):
|
||||
"""
|
||||
Add restriction of maximum concurrent async calls using priority queue.
|
||||
Lower priority value means higher priority.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of concurrent calls
|
||||
|
||||
Returns:
|
||||
A decorator that wraps an async function with priority-based concurrency control
|
||||
"""
|
||||
|
||||
def final_decro(func):
|
||||
# Create shared worker pool resources
|
||||
queue = asyncio.PriorityQueue()
|
||||
tasks = set()
|
||||
lock = asyncio.Lock()
|
||||
|
||||
# Worker function that processes tasks from the queue
|
||||
async def worker():
|
||||
"""Worker that processes tasks from the priority queue"""
|
||||
while True:
|
||||
# Get task from queue (priority, task_id, future, args, kwargs)
|
||||
priority, _, future, args, kwargs = await queue.get()
|
||||
try:
|
||||
# Execute the function
|
||||
result = await func(*args, **kwargs)
|
||||
# Set result to future if not already done
|
||||
if not future.done():
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
# Set exception to future if not already done
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
finally:
|
||||
# Mark task as done
|
||||
queue.task_done()
|
||||
|
||||
# Ensure worker tasks are started
|
||||
async def ensure_workers():
|
||||
"""Ensure worker tasks are started"""
|
||||
nonlocal tasks
|
||||
async with lock:
|
||||
if not tasks:
|
||||
# Start worker tasks
|
||||
for _ in range(max_size):
|
||||
task = asyncio.create_task(worker())
|
||||
tasks.add(task)
|
||||
# Remove task from set when done
|
||||
task.add_done_callback(tasks.discard)
|
||||
|
||||
@wraps(func)
|
||||
async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
|
||||
"""
|
||||
Execute function with priority-based concurrency control
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to pass to the function
|
||||
priority: Priority of the call (lower value means higher priority)
|
||||
timeout: Maximum time in seconds to wait for the function to complete
|
||||
**kwargs: Keyword arguments to pass to the function
|
||||
|
||||
Returns:
|
||||
Result of the function call
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the function call times out
|
||||
Any exception raised by the function
|
||||
"""
|
||||
# Ensure workers are started
|
||||
await ensure_workers()
|
||||
|
||||
# Create future for result
|
||||
future = asyncio.Future()
|
||||
|
||||
# Create unique task ID
|
||||
task_id = id(args) + id(kwargs) + id(time.time())
|
||||
|
||||
# Put task in queue with priority
|
||||
await queue.put((_priority, task_id, future, args, kwargs))
|
||||
|
||||
# Wait for result with optional timeout
|
||||
if _timeout is not None:
|
||||
try:
|
||||
return await asyncio.wait_for(future, _timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# Cancel future if possible
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
raise TimeoutError(f"Task timed out after {_timeout} seconds")
|
||||
else:
|
||||
# Wait for result without timeout
|
||||
return await future
|
||||
|
||||
return wait_func
|
||||
|
||||
return final_decro
|
||||
|
||||
|
||||
def wrap_embedding_func_with_attrs(**kwargs):
|
||||
"""Wrap a function with attributes"""
|
||||
|
||||
|
Reference in New Issue
Block a user