diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index d89e9052..a1e1f051 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -61,7 +61,7 @@ from .utils import (
     compute_mdhash_id,
     convert_response_to_json,
     lazy_external_import,
-    limit_async_func_call,
+    priority_limit_async_func_call,
     get_content_summary,
     clean_text,
     check_storage_env_vars,
@@ -338,9 +338,9 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         # Init Embedding
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
-            self.embedding_func
-        )
+        self.embedding_func = priority_limit_async_func_call(
+            self.embedding_func_max_async
+        )(self.embedding_func)
 
         # Initialize all storages
         self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
@@ -426,7 +426,7 @@ class LightRAG:
         # Directly use llm_response_cache, don't create a new object
         hashing_kv = self.llm_response_cache
 
-        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
+        self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 14bcdae2..d10cb2c2 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from functools import partial
 
 import asyncio
 import traceback
@@ -112,6 +113,9 @@ async def _handle_entity_relation_summary(
     If too long, use LLM to summarize.
     """
     use_llm_func: callable = global_config["llm_model_func"]
+    # Apply higher priority (8) to entity/relation summary tasks
+    use_llm_func = partial(use_llm_func, _priority=8)
+
     tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -136,13 +140,14 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
 
-    # Use LLM function with cache
+    # Use LLM function with cache (higher priority for summary generation)
     summary = await use_llm_func_with_cache(
         use_prompt,
         use_llm_func,
         llm_response_cache=llm_response_cache,
         max_tokens=summary_max_tokens,
         cache_type="extract",
+        priority=5,  # Higher priority for entity/relation summary
     )
     return summary
 
diff --git a/lightrag/utils.py b/lightrag/utils.py
index b57acc37..1cbe383c 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -8,6 +8,7 @@ import logging
 import logging.handlers
 import os
 import re
+import time
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
@@ -284,6 +285,105 @@ def limit_async_func_call(max_size: int):
     return final_decro
 
 
+def priority_limit_async_func_call(max_size: int):
+    """
+    Add restriction of maximum concurrent async calls using priority queue.
+    Lower priority value means higher priority.
+
+    Args:
+        max_size: Maximum number of concurrent calls
+
+    Returns:
+        A decorator that wraps an async function with priority-based concurrency control
+    """
+
+    def final_decro(func):
+        # Create shared worker pool resources
+        queue = asyncio.PriorityQueue()
+        tasks = set()
+        lock = asyncio.Lock()
+
+        # Worker function that processes tasks from the queue
+        async def worker():
+            """Worker that processes tasks from the priority queue"""
+            while True:
+                # Get task from queue (priority, task_id, future, args, kwargs)
+                priority, _, future, args, kwargs = await queue.get()
+                try:
+                    # Execute the function
+                    result = await func(*args, **kwargs)
+                    # Set result to future if not already done
+                    if not future.done():
+                        future.set_result(result)
+                except Exception as e:
+                    # Set exception to future if not already done
+                    if not future.done():
+                        future.set_exception(e)
+                finally:
+                    # Mark task as done
+                    queue.task_done()
+
+        # Ensure worker tasks are started
+        async def ensure_workers():
+            """Ensure worker tasks are started"""
+            nonlocal tasks
+            async with lock:
+                if not tasks:
+                    # Start worker tasks
+                    for _ in range(max_size):
+                        task = asyncio.create_task(worker())
+                        tasks.add(task)
+                        # Remove task from set when done
+                        task.add_done_callback(tasks.discard)
+
+        @wraps(func)
+        async def wait_func(*args, _priority=10, _timeout=None, **kwargs):
+            """
+            Execute function with priority-based concurrency control
+
+            Args:
+                *args: Positional arguments to pass to the function
+                _priority: Priority of the call (lower value means higher priority)
+                _timeout: Maximum time in seconds to wait for the function to complete
+                **kwargs: Keyword arguments to pass to the function
+
+            Returns:
+                Result of the function call
+
+            Raises:
+                TimeoutError: If the function call times out
+                Any exception raised by the function
+            """
+            # Ensure workers are started
+            await ensure_workers()
+
+            # Create future for result
+            future = asyncio.Future()
+
+            # Create task ID used as a tie-breaker so equal-priority entries
+            # never fall through to comparing futures in the priority queue
+            task_id = id(args) + id(kwargs) + id(time.time())
+
+            # Put task in queue with priority
+            await queue.put((_priority, task_id, future, args, kwargs))
+
+            # Wait for result with optional timeout
+            if _timeout is not None:
+                try:
+                    return await asyncio.wait_for(future, _timeout)
+                except asyncio.TimeoutError:
+                    # Cancel future if possible
+                    if not future.done():
+                        future.cancel()
+                    raise TimeoutError(f"Task timed out after {_timeout} seconds")
+            else:
+                # Wait for result without timeout
+                return await future
+
+        return wait_func
+
+    return final_decro
+
+
 def wrap_embedding_func_with_attrs(**kwargs):
     """Wrap a function with attributes"""
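
Usage note (not part of the patch): the wrapper adds two keyword arguments to the decorated function, _priority (default 10, lower is served first) and _timeout (seconds, covering queue wait plus execution). Below is a minimal sketch of how a caller might exercise them; fake_llm and the concrete priority values are invented for illustration, only the decorator and its _priority/_timeout parameters come from the diff above.

    import asyncio

    from lightrag.utils import priority_limit_async_func_call

    # Hypothetical stand-in for an LLM call; at most 2 run concurrently.
    @priority_limit_async_func_call(2)
    async def fake_llm(prompt: str) -> str:
        await asyncio.sleep(0.1)  # simulate network latency
        return f"answer to {prompt!r}"

    async def main():
        # _priority/_timeout are consumed by the wrapper, not passed to fake_llm
        urgent = fake_llm("summarize entity", _priority=5)  # dequeued before default-10 calls
        normal = fake_llm("background extraction")          # default _priority=10
        bounded = fake_llm("slow call", _timeout=30)        # TimeoutError if not done in 30 s
        print(await asyncio.gather(urgent, normal, bounded))

    asyncio.run(main())

Design-wise, this swaps the semaphore in limit_async_func_call for a fixed pool of max_size worker tasks draining an asyncio.PriorityQueue, which is what lets _handle_entity_relation_summary in operate.py push summary calls (priority 8 via the partial, 5 via use_llm_func_with_cache) ahead of default-priority extraction traffic instead of waiting behind it.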