Add priority control for limited async decorator

Author: yangdx
Date:   2025-04-28 18:12:29 +08:00
Commit: 140b1b3cbb (parent 0c117816dc)

3 changed files with 111 additions and 6 deletions

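In short: the existing limit_async_func_call decorator only caps the number of concurrent calls; the new priority_limit_async_func_call additionally routes every call through an asyncio.PriorityQueue drained by max_size worker tasks, so calls with a lower _priority value are dequeued first when the pool is saturated. A minimal sketch of the intended usage (the embed coroutine is hypothetical; only the decorator comes from this commit):

    import asyncio
    from lightrag.utils import priority_limit_async_func_call

    # Hypothetical coroutine standing in for a real embedding call.
    @priority_limit_async_func_call(max_size=4)
    async def embed(text: str) -> str:
        await asyncio.sleep(0.1)
        return text

    # At most 4 calls run at once; lower _priority values leave the queue first.
    print(asyncio.run(embed("hello", _priority=5)))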
lightrag/lightrag.py

@@ -61,7 +61,7 @@ from .utils import (
     compute_mdhash_id,
     convert_response_to_json,
     lazy_external_import,
-    limit_async_func_call,
+    priority_limit_async_func_call,
     get_content_summary,
     clean_text,
     check_storage_env_vars,
@@ -338,9 +338,9 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init Embedding
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
-            self.embedding_func
-        )
+        self.embedding_func = priority_limit_async_func_call(
+            self.embedding_func_max_async
+        )(self.embedding_func)

         # Initialize all storages
         self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
@@ -426,7 +426,7 @@ class LightRAG:
             # Directly use llm_response_cache, don't create a new object
             hashing_kv = self.llm_response_cache

-        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
+        self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,

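Because wait_func is applied with @wraps, the wrapped embedding and LLM functions keep their original call shape: existing call sites are unchanged and enter the queue at the default _priority=10, while any caller may pass _priority (or _timeout) explicitly. A hedged sketch of the resulting call shape (the helper below is hypothetical; rag is assumed to be an initialized LightRAG instance):

    from typing import Any

    async def query_with_priority(rag: Any, prompt: str) -> str:
        # After __post_init__, rag.llm_model_func is the decorator's wait_func.
        # _priority is consumed by the decorator and not forwarded to the LLM.
        return await rag.llm_model_func(prompt, _priority=5)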
lightrag/operate.py

@@ -1,4 +1,5 @@
 from __future__ import annotations

+from functools import partial
 import asyncio
 import traceback
@@ -112,6 +113,9 @@ async def _handle_entity_relation_summary(
     If too long, use LLM to summarize.
     """
     use_llm_func: callable = global_config["llm_model_func"]
+    # Apply higher priority (8) to entity/relation summary tasks
+    use_llm_func = partial(use_llm_func, _priority=8)
+
     tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -136,13 +140,14 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")

-    # Use LLM function with cache
+    # Use LLM function with cache (higher priority for summary generation)
     summary = await use_llm_func_with_cache(
         use_prompt,
         use_llm_func,
         llm_response_cache=llm_response_cache,
         max_tokens=summary_max_tokens,
         cache_type="extract",
+        priority=5,  # Higher priority for entity/relation summary
     )
     return summary


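Note the two priority paths here: partial binds _priority=8 onto every direct use of use_llm_func, which the decorator itself consumes, whereas priority=5 is handed to the use_llm_func_with_cache helper and only affects queue ordering if that helper translates it into _priority when it finally invokes the LLM function; this diff shows no such handling, so the effective queue priority for summaries may be the partial-bound 8. A runnable sketch of how a partial-bound value reaches the decorator (the llm coroutine is a hypothetical stand-in):

    import asyncio
    from functools import partial
    from lightrag.utils import priority_limit_async_func_call

    @priority_limit_async_func_call(max_size=1)
    async def llm(prompt: str) -> str:  # hypothetical stand-in for the real LLM call
        return f"answer to {prompt!r}"

    # partial pre-binds the keyword; the decorator pops it off before calling llm
    summary_llm = partial(llm, _priority=8)

    print(asyncio.run(summary_llm("summarize entity X")))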
lightrag/utils.py

@@ -8,6 +8,7 @@ import logging
 import logging.handlers
 import os
 import re
+import time
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
@@ -284,6 +285,105 @@ def limit_async_func_call(max_size: int):
     return final_decro


+def priority_limit_async_func_call(max_size: int):
+    """
+    Add restriction of maximum concurrent async calls using a priority queue.
+    Lower priority value means higher priority.
+
+    Args:
+        max_size: Maximum number of concurrent calls
+
+    Returns:
+        A decorator that wraps an async function with priority-based concurrency control
+    """
+
+    def final_decro(func):
+        # Shared worker-pool state for all calls to the wrapped function
+        queue = asyncio.PriorityQueue()
+        tasks = set()
+        lock = asyncio.Lock()
+
+        async def worker():
+            """Worker that processes tasks from the priority queue"""
+            while True:
+                # Get the next entry: (priority, task_id, future, args, kwargs)
+                priority, _, future, args, kwargs = await queue.get()
+                try:
+                    # Execute the function
+                    result = await func(*args, **kwargs)
+                    # Set the result unless the future is done (e.g. cancelled on timeout)
+                    if not future.done():
+                        future.set_result(result)
+                except Exception as e:
+                    # Propagate the exception through the future if still pending
+                    if not future.done():
+                        future.set_exception(e)
+                finally:
+                    # Mark the queue entry as processed
+                    queue.task_done()
+
+        async def ensure_workers():
+            """Lazily start the worker tasks on first use"""
+            nonlocal tasks
+            async with lock:
+                if not tasks:
+                    # Start max_size workers; pool size caps concurrency
+                    for _ in range(max_size):
+                        task = asyncio.create_task(worker())
+                        tasks.add(task)
+                        # Remove the task from the set when it finishes
+                        task.add_done_callback(tasks.discard)
+
+        @wraps(func)
+        async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
+            """
+            Execute the function with priority-based concurrency control
+
+            Args:
+                *args: Positional arguments to pass to the function
+                _priority: Priority of the call (lower value means higher priority, default 10)
+                _timeout: Maximum time in seconds to wait for the result, or None for no limit
+                **kwargs: Keyword arguments to pass to the function
+
+            Returns:
+                Result of the function call
+
+            Raises:
+                TimeoutError: If the function call times out
+                Any exception raised by the function
+            """
+            # Ensure workers are started
+            await ensure_workers()
+
+            # Create a future that the worker will resolve
+            future = asyncio.Future()
+
+            # Tie-breaker for entries with equal priority
+            # (note: id() sums are not guaranteed unique)
+            task_id = id(args) + id(kwargs) + id(time.time())
+
+            # Enqueue the call; workers take the lowest priority value first
+            await queue.put((_priority, task_id, future, args, kwargs))
+
+            # Wait for the result, with an optional timeout
+            if _timeout is not None:
+                try:
+                    return await asyncio.wait_for(future, _timeout)
+                except asyncio.TimeoutError:
+                    # Cancel the future so the worker skips setting a result
+                    if not future.done():
+                        future.cancel()
+                    raise TimeoutError(f"Task timed out after {_timeout} seconds")
+            else:
+                # Wait for the result without a timeout
+                return await future
+
+        return wait_func
+
+    return final_decro
+
+
 def wrap_embedding_func_with_attrs(**kwargs):
     """Wrap a function with attributes"""