Add priority control for limited async decorator

yangdx
2025-04-28 18:12:29 +08:00
parent 0c117816dc
commit 140b1b3cbb
3 changed files with 111 additions and 6 deletions

lightrag/lightrag.py

@@ -61,7 +61,7 @@ from .utils import (
     compute_mdhash_id,
     convert_response_to_json,
     lazy_external_import,
-    limit_async_func_call,
+    priority_limit_async_func_call,
     get_content_summary,
     clean_text,
     check_storage_env_vars,
@@ -338,9 +338,9 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")

         # Init Embedding
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
-            self.embedding_func
-        )
+        self.embedding_func = priority_limit_async_func_call(
+            self.embedding_func_max_async
+        )(self.embedding_func)

         # Initialize all storages
         self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
@@ -426,7 +426,7 @@ class LightRAG:
             # Directly use llm_response_cache, don't create a new object
             hashing_kv = self.llm_response_cache

-        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
+        self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
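With this change, every embedding and LLM call in LightRAG is funneled through the priority-aware limiter, and callers can opt into a higher priority via the keyword arguments the decorator injects. A minimal sketch of what a call site could look like, assuming rag is an initialized LightRAG instance (hypothetical usage, not part of this commit):

# The wait_func wrapper consumes _priority/_timeout itself; they are not
# forwarded to the underlying LLM function.
result = await rag.llm_model_func(
    "Summarize the following text ...",
    _priority=5,   # lower value = dequeued before default-priority (10) calls
    _timeout=60,   # optional: give up after 60 seconds with TimeoutError
)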

lightrag/operate.py

@@ -1,4 +1,5 @@
 from __future__ import annotations

+from functools import partial
 import asyncio
 import traceback
@@ -112,6 +113,9 @@ async def _handle_entity_relation_summary(
         If too long, use LLM to summarize.
     """
     use_llm_func: callable = global_config["llm_model_func"]
+    # Apply higher priority (8) to entity/relation summary tasks
+    use_llm_func = partial(use_llm_func, _priority=8)
+
     tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -136,13 +140,14 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")

-    # Use LLM function with cache
+    # Use LLM function with cache (higher priority for summary generation)
     summary = await use_llm_func_with_cache(
         use_prompt,
         use_llm_func,
         llm_response_cache=llm_response_cache,
         max_tokens=summary_max_tokens,
         cache_type="extract",
+        priority=5,  # Higher priority for entity/relation summary
     )

     return summary
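The partial() here pins _priority=8 onto every call made through use_llm_func in the summary path, so downstream helpers such as use_llm_func_with_cache need no knowledge of the queueing scheme. A small sketch of that pinning pattern, with a hypothetical llm stand-in:

from functools import partial

async def llm(prompt: str, **kwargs) -> str:
    # Stand-in for an LLM function already wrapped by
    # priority_limit_async_func_call; **kwargs absorbs _priority here.
    return "..."

summary_llm = partial(llm, _priority=8)
# Every later `await summary_llm(prompt)` now carries _priority=8, which
# the wait_func wrapper pops as its own keyword argument before calling
# the underlying function.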

lightrag/utils.py

@@ -8,6 +8,7 @@ import logging
 import logging.handlers
 import os
 import re
+import time
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
@@ -284,6 +285,105 @@ def limit_async_func_call(max_size: int):
     return final_decro


+def priority_limit_async_func_call(max_size: int):
+    """
+    Add restriction of maximum concurrent async calls using priority queue.
+    Lower priority value means higher priority.
+
+    Args:
+        max_size: Maximum number of concurrent calls
+
+    Returns:
+        A decorator that wraps an async function with priority-based concurrency control
+    """
+
+    def final_decro(func):
+        # Create shared worker pool resources
+        queue = asyncio.PriorityQueue()
+        tasks = set()
+        lock = asyncio.Lock()
+
+        # Worker function that processes tasks from the queue
+        async def worker():
+            """Worker that processes tasks from the priority queue"""
+            while True:
+                # Get task from queue (priority, task_id, future, args, kwargs)
+                priority, _, future, args, kwargs = await queue.get()
+                try:
+                    # Execute the function
+                    result = await func(*args, **kwargs)
+                    # Set result to future if not already done
+                    if not future.done():
+                        future.set_result(result)
+                except Exception as e:
+                    # Set exception to future if not already done
+                    if not future.done():
+                        future.set_exception(e)
+                finally:
+                    # Mark task as done
+                    queue.task_done()
+
+        # Ensure worker tasks are started
+        async def ensure_workers():
+            """Ensure worker tasks are started"""
+            nonlocal tasks
+            async with lock:
+                if not tasks:
+                    # Start worker tasks
+                    for _ in range(max_size):
+                        task = asyncio.create_task(worker())
+                        tasks.add(task)
+                        # Remove task from set when done
+                        task.add_done_callback(tasks.discard)
+
+        @wraps(func)
+        async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
+            """
+            Execute function with priority-based concurrency control
+
+            Args:
+                *args: Positional arguments to pass to the function
+                _priority: Priority of the call (lower value means higher priority)
+                _timeout: Maximum time in seconds to wait for the function to complete
+                **kwargs: Keyword arguments to pass to the function
+
+            Returns:
+                Result of the function call
+
+            Raises:
+                TimeoutError: If the function call times out
+                Any exception raised by the function
+            """
+            # Ensure workers are started
+            await ensure_workers()
+
+            # Create future for result
+            future = asyncio.Future()
+
+            # Create unique task ID
+            task_id = id(args) + id(kwargs) + id(time.time())
+
+            # Put task in queue with priority
+            await queue.put((_priority, task_id, future, args, kwargs))
+
+            # Wait for result with optional timeout
+            if _timeout is not None:
+                try:
+                    return await asyncio.wait_for(future, _timeout)
+                except asyncio.TimeoutError:
+                    # Cancel future if possible
+                    if not future.done():
+                        future.cancel()
+                    raise TimeoutError(f"Task timed out after {_timeout} seconds")
+            else:
+                # Wait for result without timeout
+                return await future
+
+        return wait_func
+
+    return final_decro
+
+
 def wrap_embedding_func_with_attrs(**kwargs):
     """Wrap a function with attributes"""