Merge pull request #1483 from danielaskdd/llm-priority-support

Feat: Introduce priority scheduling to optimize parallel execution of LLM and Embedding tasks.
This commit is contained in:
Daniel.y
2025-04-29 00:14:56 +08:00
committed by GitHub
20 changed files with 363 additions and 118 deletions
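Concretely, every LLM and embedding call can now carry a numeric priority, and a bounded pool of workers always serves the lowest (priority, counter) pair first, so interactive queries are not starved by background extraction work. A minimal sketch of that scheduling principle (the labels are illustrative):

import asyncio

async def demo() -> None:
    # (priority, counter, label): lower priority numbers are served first,
    # and the counter preserves FIFO order among equal priorities.
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    await queue.put((10, 0, "background extraction chunk"))  # default priority
    await queue.put((5, 1, "user query embedding"))          # boosted
    await queue.put((8, 2, "entity/relation summary"))       # boosted
    while not queue.empty():
        priority, _, label = await queue.get()
        print(priority, label)  # prints 5, then 8, then 10

asyncio.run(demo())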

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.3.5"
__version__ = "1.3.6"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -1 +1 @@
__api_version__ = "0163"
__api_version__ = "0164"

View File

@@ -1,4 +1,4 @@
import{j as o,Y as td,O as fg,k as dg,u as ad,Z as mg,c as hg,l as gg,g as pg,S as yg,T as vg,n as bg,m as nd,o as Sg,p as Tg,$ as ud,a0 as id,a1 as cd,a2 as xg}from"./ui-vendor-CeCm8EER.js";import{d as Ag,h as Dg,r as E,u as sd,H as Ng,i as Eg,j as kf}from"./react-vendor-DEwriMA6.js";import{w as Xe,c as Ve,a2 as od,u as ql,v as Gt,a3 as rd,a4 as fd,I as us,B as Cn,D as Mg,i as zg,j as Cg,k as Og,l as jg,a5 as Rg,a6 as Ug,a7 as _g,a8 as Hg,a9 as Ll,aa as dd,ab as ss,ac as is,ad as Lg,ae as qg,af as Bg,ag as Gg,ah as Yg,ai as wg,aj as md,ak as Xg,al as Vg,am as hd,an as Qg,ao as gd,C as Kg,z as Zg,G as kg,d as En,ap as Jg,aq as Fg,ar as $g}from"./feature-graph-trMp_ED2.js";import{S as Jf,a as Ff,b as $f,c as Wf,d as ot,R as Wg}from"./feature-retrieval-DaoLh-kj.js";import{D as Pg}from"./feature-documents-DCEXq3Fi.js";import{i as cs}from"./utils-vendor-BysuhMZA.js";import"./graph-vendor-B-X5JegA.js";import"./mermaid-vendor-c8YIQY7F.js";import"./markdown-vendor-BBaHfVvE.js";(function(){const b=document.createElement("link").relList;if(b&&b.supports&&b.supports("modulepreload"))return;for(const N of document.querySelectorAll('link[rel="modulepreload"]'))d(N);new MutationObserver(N=>{for(const j of N)if(j.type==="childList")for(const H of j.addedNodes)H.tagName==="LINK"&&H.rel==="modulepreload"&&d(H)}).observe(document,{childList:!0,subtree:!0});function x(N){const j={};return N.integrity&&(j.integrity=N.integrity),N.referrerPolicy&&(j.referrerPolicy=N.referrerPolicy),N.crossOrigin==="use-credentials"?j.credentials="include":N.crossOrigin==="anonymous"?j.credentials="omit":j.credentials="same-origin",j}function d(N){if(N.ep)return;N.ep=!0;const j=x(N);fetch(N.href,j)}})();var ts={exports:{}},Mn={},as={exports:{}},ns={};/**
import{j as o,Y as td,O as fg,k as dg,u as ad,Z as mg,c as hg,l as gg,g as pg,S as yg,T as vg,n as bg,m as nd,o as Sg,p as Tg,$ as ud,a0 as id,a1 as cd,a2 as xg}from"./ui-vendor-CeCm8EER.js";import{d as Ag,h as Dg,r as E,u as sd,H as Ng,i as Eg,j as kf}from"./react-vendor-DEwriMA6.js";import{w as Xe,c as Ve,a2 as od,u as ql,v as Gt,a3 as rd,a4 as fd,I as us,B as Cn,D as Mg,i as zg,j as Cg,k as Og,l as jg,a5 as Rg,a6 as Ug,a7 as _g,a8 as Hg,a9 as Ll,aa as dd,ab as ss,ac as is,ad as Lg,ae as qg,af as Bg,ag as Gg,ah as Yg,ai as wg,aj as md,ak as Xg,al as Vg,am as hd,an as Qg,ao as gd,C as Kg,z as Zg,G as kg,d as En,ap as Jg,aq as Fg,ar as $g}from"./feature-graph-trMp_ED2.js";import{S as Jf,a as Ff,b as $f,c as Wf,d as ot,R as Wg}from"./feature-retrieval-DaoLh-kj.js";import{D as Pg}from"./feature-documents-DMvE8vgg.js";import{i as cs}from"./utils-vendor-BysuhMZA.js";import"./graph-vendor-B-X5JegA.js";import"./mermaid-vendor-c8YIQY7F.js";import"./markdown-vendor-BBaHfVvE.js";(function(){const b=document.createElement("link").relList;if(b&&b.supports&&b.supports("modulepreload"))return;for(const N of document.querySelectorAll('link[rel="modulepreload"]'))d(N);new MutationObserver(N=>{for(const j of N)if(j.type==="childList")for(const H of j.addedNodes)H.tagName==="LINK"&&H.rel==="modulepreload"&&d(H)}).observe(document,{childList:!0,subtree:!0});function x(N){const j={};return N.integrity&&(j.integrity=N.integrity),N.referrerPolicy&&(j.referrerPolicy=N.referrerPolicy),N.crossOrigin==="use-credentials"?j.credentials="include":N.crossOrigin==="anonymous"?j.credentials="omit":j.credentials="same-origin",j}function d(N){if(N.ep)return;N.ep=!0;const j=x(N);fetch(N.href,j)}})();var ts={exports:{}},Mn={},as={exports:{}},ns={};/**
* @license React
* scheduler.production.js
*

File diff suppressed because one or more lines are too long

View File

@@ -8,7 +8,7 @@
<link rel="icon" type="image/svg+xml" href="logo.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Lightrag</title>
<script type="module" crossorigin src="/webui/assets/index-QYZi8qob.js"></script>
<script type="module" crossorigin src="/webui/assets/index-C7DV8sI5.js"></script>
<link rel="modulepreload" crossorigin href="/webui/assets/react-vendor-DEwriMA6.js">
<link rel="modulepreload" crossorigin href="/webui/assets/ui-vendor-CeCm8EER.js">
<link rel="modulepreload" crossorigin href="/webui/assets/graph-vendor-B-X5JegA.js">
@@ -17,9 +17,9 @@
<link rel="modulepreload" crossorigin href="/webui/assets/mermaid-vendor-c8YIQY7F.js">
<link rel="modulepreload" crossorigin href="/webui/assets/markdown-vendor-BBaHfVvE.js">
<link rel="modulepreload" crossorigin href="/webui/assets/feature-retrieval-DaoLh-kj.js">
<link rel="modulepreload" crossorigin href="/webui/assets/feature-documents-DCEXq3Fi.js">
<link rel="modulepreload" crossorigin href="/webui/assets/feature-documents-DMvE8vgg.js">
<link rel="stylesheet" crossorigin href="/webui/assets/feature-graph-BipNuM18.css">
<link rel="stylesheet" crossorigin href="/webui/assets/index-DsHQCgEh.css">
<link rel="stylesheet" crossorigin href="/webui/assets/index-Dd2S8XA6.css">
</head>
<body>
<div id="root"></div>

View File

@@ -161,7 +161,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
try:
embedding = await self.embedding_func([query])
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
results = self._collection.query(
query_embeddings=embedding.tolist()
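The same two-line change is repeated for every vector backend in this PR (Chroma above, then Faiss, Milvus, MongoDB Atlas, NanoVectorDB, PGVector, Qdrant and TiDB below): query-time embedding calls pass _priority=5 so they are scheduled ahead of default-priority (10) indexing work. A hedged sketch of the call pattern, using an illustrative storage class rather than any particular backend:

import numpy as np

class ExampleVectorStorage:
    def __init__(self, embedding_func):
        # embedding_func is assumed to be wrapped by
        # priority_limit_async_func_call, so it accepts _priority.
        self.embedding_func = embedding_func

    async def query(self, query: str, top_k: int) -> list[dict]:
        # _priority=5 outranks the default of 10 used by background
        # indexing, so interactive queries are embedded first under load.
        embedding = await self.embedding_func([query], _priority=5)
        vector = np.asarray(embedding, dtype=np.float32)[0]
        return self._search(vector, top_k)  # backend-specific search (stubbed)

    def _search(self, vector, top_k):
        return []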

View File

@@ -175,7 +175,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
"""
Search by a textual query; returns top_k results with their metadata + similarity distance.
"""
embedding = await self.embedding_func([query])
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# embedding is shape (1, dim)
embedding = np.array(embedding, dtype=np.float32)
faiss.normalize_L2(embedding) # we do in-place normalization

View File

@@ -197,3 +197,10 @@ class JsonKVStorage(BaseKVStorage):
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}
async def finalize(self):
"""Finalize storage resources
Persist cache data to disk before exiting
"""
if self.namespace.endswith("cache"):
await self.index_done_callback()
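A short sketch of how the new finalize() hook might be driven at shutdown (the storage list here is hypothetical): cache namespaces are flushed to disk via index_done_callback(), other namespaces are left untouched:

async def shutdown_storages(storages: list) -> None:
    # Persist any cache namespaces before the process exits.
    for storage in storages:
        if hasattr(storage, "finalize"):
            await storage.finalize()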

View File

@@ -104,7 +104,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
results = self._client.search(
collection_name=self.namespace,
data=embedding,

View File

@@ -1032,7 +1032,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search."""
# Generate the embedding
embedding = await self.embedding_func([query])
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
# Convert numpy array to a list to ensure compatibility with MongoDB
query_vector = embedding[0].tolist()

View File

@@ -124,8 +124,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
# Execute embedding outside of lock to avoid long lock times
embedding = await self.embedding_func([query])
# Execute embedding outside of lock to improve concurrency
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embedding[0]
client = await self._get_client()

View File

@@ -644,7 +644,9 @@ class PGVectorStorage(BaseVectorStorage):
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embeddings = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding))
# Use parameterized document IDs (None means search across all documents)

View File

@@ -124,7 +124,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
embedding = await self.embedding_func(
[query], _priority=5
) # higher priority for query
results = self._client.search(
collection_name=self.namespace,
query_vector=embedding[0],

View File

@@ -65,17 +65,17 @@ class UnifiedLock(Generic[T]):
async def __aenter__(self) -> "UnifiedLock[T]":
try:
direct_log(
f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
enable_output=self._enable_logging,
)
# direct_log(
# f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
# enable_output=self._enable_logging,
# )
# If in multiprocess mode and async lock exists, acquire it first
if not self._is_async and self._async_lock is not None:
direct_log(
f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
enable_output=self._enable_logging,
)
# direct_log(
# f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
# enable_output=self._enable_logging,
# )
await self._async_lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired",
@@ -112,10 +112,10 @@ class UnifiedLock(Generic[T]):
async def __aexit__(self, exc_type, exc_val, exc_tb):
main_lock_released = False
try:
direct_log(
f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
enable_output=self._enable_logging,
)
# direct_log(
# f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
# enable_output=self._enable_logging,
# )
# Release main lock first
if self._is_async:
@@ -127,10 +127,10 @@ class UnifiedLock(Generic[T]):
# Then release async lock if in multiprocess mode
if not self._is_async and self._async_lock is not None:
direct_log(
f"== Lock == Process {self._pid}: Releasing async lock for '{self._name}'",
enable_output=self._enable_logging,
)
# direct_log(
# f"== Lock == Process {self._pid}: Releasing async lock for '{self._name}'",
# enable_output=self._enable_logging,
# )
self._async_lock.release()
direct_log(

View File

@@ -390,7 +390,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
embeddings = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embeddings[0]
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

View File

@@ -61,7 +61,7 @@ from .utils import (
compute_mdhash_id,
convert_response_to_json,
lazy_external_import,
limit_async_func_call,
priority_limit_async_func_call,
get_content_summary,
clean_text,
check_storage_env_vars,
@@ -338,9 +338,9 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
self.embedding_func
)
self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async
)(self.embedding_func)
# Initialize all storages
self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
@@ -426,7 +426,7 @@ class LightRAG:
# Directly use llm_response_cache, don't create a new object
hashing_kv = self.llm_response_cache
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func, # type: ignore
hashing_kv=hashing_kv,
@@ -1024,7 +1024,8 @@ class LightRAG:
}
)
# Release semaphore before entering the merge stage
# Semaphore was released here
if file_extraction_stage_ok:
try:
# Get chunk_results from entity_relation_task
@@ -1152,9 +1153,6 @@ class LightRAG:
try:
chunk_results = await extract_entities(
chunk,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
pipeline_status=pipeline_status,
pipeline_status_lock=pipeline_status_lock,
@@ -1445,6 +1443,9 @@ class LightRAG:
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval
use_llm_func = param.model_func or global_config["llm_model_func"]
# Apply higher priority (8) to the bypass-mode LLM call
use_llm_func = partial(use_llm_func, _priority=8)
param.stream = True if param.stream is None else param.stream
response = await use_llm_func(
query.strip(),
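A simplified sketch of the init-time wiring introduced above (the model functions are hypothetical stand-ins, and hashing_kv is stubbed): the raw embedding and LLM callables are wrapped once, after which every call site can pass _priority without knowing about the underlying queue:

from functools import partial
from lightrag.utils import priority_limit_async_func_call

async def my_embedding_func(texts: list[str]):
    ...  # call the real embedding model here

async def my_llm_model_func(prompt: str, **kwargs):
    ...  # call the real LLM here

# Wrap once with the concurrency limits (values are illustrative).
embedding_func = priority_limit_async_func_call(16)(my_embedding_func)
llm_model_func = priority_limit_async_func_call(4)(
    partial(my_llm_model_func, hashing_kv=None)  # real code binds llm_response_cache
)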

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from functools import partial
import asyncio
import traceback
@@ -112,6 +113,9 @@ async def _handle_entity_relation_summary(
If too long, use LLM to summarize.
"""
use_llm_func: callable = global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8)
tokenizer: Tokenizer = global_config["tokenizer"]
llm_max_tokens = global_config["llm_model_max_token_size"]
summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -136,7 +140,7 @@ async def _handle_entity_relation_summary(
use_prompt = prompt_template.format(**context_base)
logger.debug(f"Trigger summary: {entity_or_relation_name}")
# Use LLM function with cache
# Use LLM function with cache (higher priority for summary generation)
summary = await use_llm_func_with_cache(
use_prompt,
use_llm_func,
@@ -504,8 +508,6 @@ async def merge_nodes_and_edges(
# Get lock manager from shared storage
from .kg.shared_storage import get_graph_db_lock
graph_db_lock = get_graph_db_lock(enable_logging=False)
# Collect all nodes and edges from all chunks
all_nodes = defaultdict(list)
all_edges = defaultdict(list)
@@ -526,6 +528,7 @@ async def merge_nodes_and_edges(
# Merge nodes and edges
# Use graph database lock to ensure atomic merges and updates
graph_db_lock = get_graph_db_lock(enable_logging=True)
async with graph_db_lock:
async with pipeline_status_lock:
log_message = (
@@ -612,9 +615,6 @@ async def merge_nodes_and_edges(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict[str, str],
pipeline_status: dict = None,
pipeline_status_lock=None,
@@ -849,12 +849,14 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
if query_param.model_func:
use_model_func = query_param.model_func
else:
use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to the query-related LLM function
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1050,9 +1052,13 @@ async def extract_keywords_only(
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
# 5. Call the LLM for keyword extraction
use_model_func = (
param.model_func if param.model_func else global_config["llm_model_func"]
)
if param.model_func:
use_model_func = param.model_func
else:
use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to the query-related LLM function
use_model_func = partial(use_model_func, _priority=5)
result = await use_model_func(kw_prompt, keyword_extraction=True)
# 6. Parse out JSON from the LLM response
@@ -1115,12 +1121,15 @@ async def mix_kg_vector_query(
"""
# get tokenizer
tokenizer: Tokenizer = global_config["tokenizer"]
if query_param.model_func:
use_model_func = query_param.model_func
else:
use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to the query-related LLM function
use_model_func = partial(use_model_func, _priority=5)
# 1. Cache handling
use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash("mix", query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -2006,12 +2015,14 @@ async def naive_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
if query_param.model_func:
use_model_func = query_param.model_func
else:
use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to the query-related LLM function
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -2138,15 +2149,16 @@ async def kg_query_with_keywords(
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
Then it uses those to build context and produce a final LLM response.
"""
if query_param.model_func:
use_model_func = query_param.model_func
else:
use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to the query-related LLM function
use_model_func = partial(use_model_func, _priority=5)
# ---------------------------
# 1) Handle potential cache for query results
# ---------------------------
use_model_func = (
query_param.model_func
if query_param.model_func
else global_config["llm_model_func"]
)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
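The same refactor recurs in kg_query, extract_keywords_only, mix_kg_vector_query, naive_query and kg_query_with_keywords: resolve the model function first, then bind a higher priority with functools.partial. A minimal sketch of the shared pattern (not the full query bodies):

from functools import partial

def resolve_query_llm(query_param, global_config: dict):
    # Prefer a per-query model function, fall back to the global one.
    if query_param.model_func:
        use_model_func = query_param.model_func
    else:
        use_model_func = global_config["llm_model_func"]
    # Lower number = higher priority; 5 outranks the default of 10
    # used by background extraction calls.
    return partial(use_model_func, _priority=5)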

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import weakref
import asyncio
import html
@@ -267,17 +268,261 @@ def compute_mdhash_id(content: str, prefix: str = "") -> str:
return prefix + md5(content.encode()).hexdigest()
def limit_async_func_call(max_size: int):
"""Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
# Custom exception class
class QueueFullError(Exception):
"""Raised when the queue is full and the wait times out"""
pass
def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
"""
Enhanced priority-limited asynchronous function call decorator
Args:
max_size: Maximum number of concurrent calls
max_queue_size: Maximum queue capacity to prevent memory overflow
Returns:
Decorator function
"""
def final_decro(func):
sem = asyncio.Semaphore(max_size)
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
tasks = set()
initialization_lock = asyncio.Lock()
counter = 0
shutdown_event = asyncio.Event()
initialized = False # Global initialization flag
worker_health_check_task = None
# Track active future objects for cleanup
active_futures = weakref.WeakSet()
# Worker function to process tasks in the queue
async def worker():
"""Worker that processes tasks in the priority queue"""
try:
while not shutdown_event.is_set():
try:
# Use timeout to get tasks, allowing periodic checking of shutdown signal
try:
(
priority,
count,
future,
args,
kwargs,
) = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError:
# Timeout is just to check shutdown signal, continue to next iteration
continue
# If future is cancelled, skip execution
if future.cancelled():
queue.task_done()
continue
try:
# Execute function
result = await func(*args, **kwargs)
# If future is not done, set the result
if not future.done():
future.set_result(result)
except asyncio.CancelledError:
if not future.done():
future.cancel()
logger.debug("limit_async: Task cancelled during execution")
except Exception as e:
logger.error(
f"limit_async: Error in decorated function: {str(e)}"
)
if not future.done():
future.set_exception(e)
finally:
queue.task_done()
except Exception as e:
# Catch all exceptions in worker loop to prevent worker termination
logger.error(f"limit_async: Critical error in worker: {str(e)}")
await asyncio.sleep(0.1) # Prevent high CPU usage
finally:
logger.warning("limit_async: Worker exiting")
async def health_check():
"""Periodically check worker health status and recover"""
try:
while not shutdown_event.is_set():
await asyncio.sleep(5) # Check every 5 seconds
# No longer acquire lock, directly operate on task set
# Use a copy of the task set to avoid concurrent modification
current_tasks = set(tasks)
done_tasks = {t for t in current_tasks if t.done()}
tasks.difference_update(done_tasks)
# Calculate active tasks count
active_tasks_count = len(tasks)
workers_needed = max_size - active_tasks_count
if workers_needed > 0:
logger.info(
f"limit_async: Creating {workers_needed} new workers"
)
new_tasks = set()
for _ in range(workers_needed):
task = asyncio.create_task(worker())
new_tasks.add(task)
task.add_done_callback(tasks.discard)
# Update task set in one operation
tasks.update(new_tasks)
except Exception as e:
logger.error(f"limit_async: Error in health check: {str(e)}")
finally:
logger.warning("limit_async: Health check task exiting")
async def ensure_workers():
"""Ensure worker threads and health check system are available
This function checks if the worker system is already initialized.
If not, it performs a one-time initialization of all worker threads
and starts the health check system.
"""
nonlocal initialized, worker_health_check_task, tasks
if initialized:
return
async with initialization_lock:
if initialized:
return
logger.info("limit_async: Initializing worker system")
# Create initial worker tasks
for _ in range(max_size):
task = asyncio.create_task(worker())
tasks.add(task)
task.add_done_callback(tasks.discard)
# Start health check
worker_health_check_task = asyncio.create_task(health_check())
initialized = True
logger.info("limit_async: Worker system initialized")
async def shutdown():
"""Gracefully shut down all workers and the queue"""
logger.info("limit_async: Shutting down priority queue workers")
# Set the shutdown event
shutdown_event.set()
# Cancel all active futures
for future in list(active_futures):
if not future.done():
future.cancel()
# Wait for the queue to empty
try:
await asyncio.wait_for(queue.join(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning(
"limit_async: Timeout waiting for queue to empty during shutdown"
)
# Cancel all worker tasks
for task in list(tasks):
if not task.done():
task.cancel()
# Wait for all tasks to complete
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
# Cancel the health check task
if worker_health_check_task and not worker_health_check_task.done():
worker_health_check_task.cancel()
try:
await worker_health_check_task
except asyncio.CancelledError:
pass
logger.info("limit_async: Priority queue workers shutdown complete")
@wraps(func)
async def wait_func(*args, **kwargs):
async with sem:
result = await func(*args, **kwargs)
return result
async def wait_func(
*args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
):
"""
Execute the function with priority-based concurrency control
Args:
*args: Positional arguments passed to the function
_priority: Call priority (lower values have higher priority)
_timeout: Maximum time to wait for function completion (in seconds)
_queue_timeout: Maximum time to wait for entering the queue (in seconds)
**kwargs: Keyword arguments passed to the function
Returns:
The result of the function call
Raises:
TimeoutError: If the function call times out
QueueFullError: If the queue is full and waiting times out
Any exception raised by the decorated function
"""
# Ensure worker system is initialized
await ensure_workers()
# Create a future for the result
future = asyncio.Future()
active_futures.add(future)
nonlocal counter
async with initialization_lock:
current_count = counter
counter += 1
# Try to put the task into the queue, supporting timeout
try:
if _queue_timeout is not None:
# Use timeout to wait for queue space
try:
await asyncio.wait_for(
queue.put((_priority, current_count, future, args, kwargs)),
timeout=_queue_timeout,
)
except asyncio.TimeoutError:
raise QueueFullError(
f"Queue full, timeout after {_queue_timeout} seconds"
)
else:
# No timeout, may wait indefinitely
await queue.put((_priority, current_count, future, args, kwargs))
except Exception as e:
# Clean up the future
if not future.done():
future.set_exception(e)
active_futures.discard(future)
raise
try:
# Wait for the result, optional timeout
if _timeout is not None:
try:
return await asyncio.wait_for(future, _timeout)
except asyncio.TimeoutError:
# Cancel the future
if not future.done():
future.cancel()
raise TimeoutError(
f"limit_async: Task timed out after {_timeout} seconds"
)
else:
# Wait for the result without timeout
return await future
finally:
# Clean up the future reference
active_futures.discard(future)
# Add the shutdown method to the decorated function
wait_func.shutdown = shutdown
return wait_func
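An end-to-end usage sketch of the decorator above (the workload and limits are illustrative, and the import assumes the package layout shown in this PR): calls enqueue at the default priority 10, a caller can pass a lower number to jump the queue, and the optional _timeout / _queue_timeout arguments surface as TimeoutError and QueueFullError:

import asyncio
from lightrag.utils import priority_limit_async_func_call, QueueFullError

@priority_limit_async_func_call(max_size=2, max_queue_size=100)
async def fake_llm(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for a real model call
    return f"answer to {prompt!r}"

async def main() -> None:
    background = [fake_llm(f"chunk {i}") for i in range(8)]  # default priority 10
    urgent = fake_llm("user query", _priority=5, _timeout=30, _queue_timeout=5)
    try:
        results = await asyncio.gather(urgent, *background)
        print(results[0])  # the urgent call's result
    except (TimeoutError, QueueFullError) as exc:
        print(f"call rejected or timed out: {exc}")
    finally:
        await fake_llm.shutdown()  # stop workers and the health check task

asyncio.run(main())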
@@ -726,46 +971,10 @@ async def handle_cache(
if mode != "default": # handle cache for all type of query
if not hashing_kv.global_config.get("enable_llm_cache"):
return None, None, None, None
# TODO: deprecated (PostgreSQL cache not implemented yet)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config",
{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
use_llm_check = embedding_cache_config.get("use_llm_check", False)
quantized = min_val = max_val = None
if is_embedding_cache_enabled: # Use embedding similarity to match cache
current_embedding = await hashing_kv.embedding_func([prompt])
llm_model_func = hashing_kv.global_config.get("llm_model_func")
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
use_llm_check=use_llm_check,
llm_func=llm_model_func if use_llm_check else None,
original_prompt=prompt,
cache_type=cache_type,
)
if best_cached_response is not None:
logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
return best_cached_response, None, None, None
else:
# if caching keyword embedding is enabled, return the quantized embedding for saving it later
logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
return None, quantized, min_val, max_val
else: # handle cache for entity extraction
if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
return None, None, None, None
# Conditions under which the code reaches this point:
# 1. All query modes: enable_llm_cache is True and embedding similarity is not enabled
# 2. Entity extract: enable_llm_cache_for_entity_extract is True
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else:
@@ -1340,7 +1549,7 @@ async def use_llm_func_with_cache(
Args:
input_text: Input text to send to LLM
use_llm_func: LLM function to call
use_llm_func: LLM function with higher priority
llm_response_cache: Cache storage instance
max_tokens: Maximum tokens for generation
history_messages: History messages list

View File

@@ -85,7 +85,7 @@ export default function PipelineStatusDialog({
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent
className={cn(
'sm:max-w-[600px] transition-all duration-200 fixed',
'sm:max-w-[800px] transition-all duration-200 fixed',
position === 'left' && '!left-[25%] !translate-x-[-50%] !mx-4',
position === 'center' && '!left-1/2 !-translate-x-1/2',
position === 'right' && '!left-[75%] !translate-x-[-50%] !mx-4'
@@ -166,7 +166,7 @@ export default function PipelineStatusDialog({
{/* Latest Message */}
<div className="space-y-2">
<div className="text-sm font-medium">{t('documentPanel.pipelineStatus.latestMessage')}:</div>
<div className="font-mono text-xs rounded-md bg-zinc-800 text-zinc-100 p-3">
<div className="font-mono text-xs rounded-md bg-zinc-800 text-zinc-100 p-3 whitespace-pre-wrap break-words">
{status?.latest_message || '-'}
</div>
</div>
@@ -181,7 +181,7 @@ export default function PipelineStatusDialog({
>
{status?.history_messages?.length ? (
status.history_messages.map((msg, idx) => (
<div key={idx}>{msg}</div>
<div key={idx} className="whitespace-pre-wrap break-words">{msg}</div>
))
) : '-'}
</div>