Add priority control for limited async decorator

This commit is contained in:
yangdx
2025-04-28 18:12:29 +08:00
parent 0c117816dc
commit 140b1b3cbb
3 changed files with 111 additions and 6 deletions

View File

@@ -8,6 +8,7 @@ import logging
import logging.handlers
import os
import re
import time
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
@@ -284,6 +285,105 @@ def limit_async_func_call(max_size: int):
return final_decro
def priority_limit_async_func_call(max_size: int):
"""
Add restriction of maximum concurrent async calls using priority queue.
Lower priority value means higher priority.
Args:
max_size: Maximum number of concurrent calls
Returns:
A decorator that wraps an async function with priority-based concurrency control
"""
def final_decro(func):
# Create shared worker pool resources
queue = asyncio.PriorityQueue()
tasks = set()
lock = asyncio.Lock()
# Worker function that processes tasks from the queue
async def worker():
"""Worker that processes tasks from the priority queue"""
while True:
# Get task from queue (priority, task_id, future, args, kwargs)
priority, _, future, args, kwargs = await queue.get()
try:
# Execute the function
result = await func(*args, **kwargs)
# Set result to future if not already done
if not future.done():
future.set_result(result)
except Exception as e:
# Set exception to future if not already done
if not future.done():
future.set_exception(e)
finally:
# Mark task as done
queue.task_done()
# Ensure worker tasks are started
async def ensure_workers():
"""Ensure worker tasks are started"""
nonlocal tasks
async with lock:
if not tasks:
# Start worker tasks
for _ in range(max_size):
task = asyncio.create_task(worker())
tasks.add(task)
# Remove task from set when done
task.add_done_callback(tasks.discard)
@wraps(func)
async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
"""
Execute function with priority-based concurrency control
Args:
*args: Positional arguments to pass to the function
priority: Priority of the call (lower value means higher priority)
timeout: Maximum time in seconds to wait for the function to complete
**kwargs: Keyword arguments to pass to the function
Returns:
Result of the function call
Raises:
TimeoutError: If the function call times out
Any exception raised by the function
"""
# Ensure workers are started
await ensure_workers()
# Create future for result
future = asyncio.Future()
# Create unique task ID
task_id = id(args) + id(kwargs) + id(time.time())
# Put task in queue with priority
await queue.put((_priority, task_id, future, args, kwargs))
# Wait for result with optional timeout
if _timeout is not None:
try:
return await asyncio.wait_for(future, _timeout)
except asyncio.TimeoutError:
# Cancel future if possible
if not future.done():
future.cancel()
raise TimeoutError(f"Task timed out after {_timeout} seconds")
else:
# Wait for result without timeout
return await future
return wait_func
return final_decro
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""