Add priority control for limited async decorator
This commit is contained in:
@@ -61,7 +61,7 @@ from .utils import (
|
||||
compute_mdhash_id,
|
||||
convert_response_to_json,
|
||||
lazy_external_import,
|
||||
limit_async_func_call,
|
||||
priority_limit_async_func_call,
|
||||
get_content_summary,
|
||||
clean_text,
|
||||
check_storage_env_vars,
|
||||
@@ -338,9 +338,9 @@ class LightRAG:
|
||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
# Init Embedding
|
||||
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
|
||||
self.embedding_func
|
||||
)
|
||||
self.embedding_func = priority_limit_async_func_call(
|
||||
self.embedding_func_max_async
|
||||
)(self.embedding_func)
|
||||
|
||||
# Initialize all storages
|
||||
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
|
||||
@@ -426,7 +426,7 @@ class LightRAG:
|
||||
# Directly use llm_response_cache, don't create a new object
|
||||
hashing_kv = self.llm_response_cache
|
||||
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
|
||||
partial(
|
||||
self.llm_model_func, # type: ignore
|
||||
hashing_kv=hashing_kv,
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from functools import partial
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
@@ -112,6 +113,9 @@ async def _handle_entity_relation_summary(
|
||||
If too long, use LLM to summarize.
|
||||
"""
|
||||
use_llm_func: callable = global_config["llm_model_func"]
|
||||
# Apply higher priority (8) to entity/relation summary tasks
|
||||
use_llm_func = partial(use_llm_func, _priority=8)
|
||||
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
llm_max_tokens = global_config["llm_model_max_token_size"]
|
||||
summary_max_tokens = global_config["summary_to_max_tokens"]
|
||||
@@ -136,13 +140,14 @@ async def _handle_entity_relation_summary(
|
||||
use_prompt = prompt_template.format(**context_base)
|
||||
logger.debug(f"Trigger summary: {entity_or_relation_name}")
|
||||
|
||||
# Use LLM function with cache
|
||||
# Use LLM function with cache (higher priority for summary generation)
|
||||
summary = await use_llm_func_with_cache(
|
||||
use_prompt,
|
||||
use_llm_func,
|
||||
llm_response_cache=llm_response_cache,
|
||||
max_tokens=summary_max_tokens,
|
||||
cache_type="extract",
|
||||
priority=5, # Higher priority for entity/relation summary
|
||||
)
|
||||
return summary
|
||||
|
||||
|
@@ -8,6 +8,7 @@ import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from hashlib import md5
|
||||
@@ -284,6 +285,105 @@ def limit_async_func_call(max_size: int):
|
||||
return final_decro
|
||||
|
||||
|
||||
def priority_limit_async_func_call(max_size: int):
|
||||
"""
|
||||
Add restriction of maximum concurrent async calls using priority queue.
|
||||
Lower priority value means higher priority.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of concurrent calls
|
||||
|
||||
Returns:
|
||||
A decorator that wraps an async function with priority-based concurrency control
|
||||
"""
|
||||
|
||||
def final_decro(func):
|
||||
# Create shared worker pool resources
|
||||
queue = asyncio.PriorityQueue()
|
||||
tasks = set()
|
||||
lock = asyncio.Lock()
|
||||
|
||||
# Worker function that processes tasks from the queue
|
||||
async def worker():
|
||||
"""Worker that processes tasks from the priority queue"""
|
||||
while True:
|
||||
# Get task from queue (priority, task_id, future, args, kwargs)
|
||||
priority, _, future, args, kwargs = await queue.get()
|
||||
try:
|
||||
# Execute the function
|
||||
result = await func(*args, **kwargs)
|
||||
# Set result to future if not already done
|
||||
if not future.done():
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
# Set exception to future if not already done
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
finally:
|
||||
# Mark task as done
|
||||
queue.task_done()
|
||||
|
||||
# Ensure worker tasks are started
|
||||
async def ensure_workers():
|
||||
"""Ensure worker tasks are started"""
|
||||
nonlocal tasks
|
||||
async with lock:
|
||||
if not tasks:
|
||||
# Start worker tasks
|
||||
for _ in range(max_size):
|
||||
task = asyncio.create_task(worker())
|
||||
tasks.add(task)
|
||||
# Remove task from set when done
|
||||
task.add_done_callback(tasks.discard)
|
||||
|
||||
@wraps(func)
|
||||
async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
|
||||
"""
|
||||
Execute function with priority-based concurrency control
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to pass to the function
|
||||
priority: Priority of the call (lower value means higher priority)
|
||||
timeout: Maximum time in seconds to wait for the function to complete
|
||||
**kwargs: Keyword arguments to pass to the function
|
||||
|
||||
Returns:
|
||||
Result of the function call
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the function call times out
|
||||
Any exception raised by the function
|
||||
"""
|
||||
# Ensure workers are started
|
||||
await ensure_workers()
|
||||
|
||||
# Create future for result
|
||||
future = asyncio.Future()
|
||||
|
||||
# Create unique task ID
|
||||
task_id = id(args) + id(kwargs) + id(time.time())
|
||||
|
||||
# Put task in queue with priority
|
||||
await queue.put((_priority, task_id, future, args, kwargs))
|
||||
|
||||
# Wait for result with optional timeout
|
||||
if _timeout is not None:
|
||||
try:
|
||||
return await asyncio.wait_for(future, _timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# Cancel future if possible
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
raise TimeoutError(f"Task timed out after {_timeout} seconds")
|
||||
else:
|
||||
# Wait for result without timeout
|
||||
return await future
|
||||
|
||||
return wait_func
|
||||
|
||||
return final_decro
|
||||
|
||||
|
||||
def wrap_embedding_func_with_attrs(**kwargs):
|
||||
"""Wrap a function with attributes"""
|
||||
|
||||
|
Reference in New Issue
Block a user