Merge pull request #1483 from danielaskdd/llm-priority-support
Feat: Introduce priority scheduling to optimize parallel execution of LLM and Embedding tasks.
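In short: the change wraps the embedding and LLM callables in a new priority_limit_async_func_call decorator (added to lightrag/utils.py below) and threads a _priority keyword through the call sites: 5 for query-time embedding and LLM calls, 8 for entity/relation summaries, 10 by default; lower values are scheduled first. A minimal usage sketch follows (illustrative only: the decorator, _priority and the shutdown() hook come from this diff, while fake_llm, the pool size and the timings are invented):

# Illustrative sketch only; fake_llm and the numbers are stand-ins.
import asyncio

from lightrag.utils import priority_limit_async_func_call


@priority_limit_async_func_call(max_size=4)
async def fake_llm(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for a real LLM/embedding round trip
    return f"answer to: {prompt}"


async def main():
    # Saturate the pool with default-priority (10) indexing work...
    background = [asyncio.create_task(fake_llm(f"chunk {i}")) for i in range(20)]
    # ...then submit an interactive call with a smaller (= higher) priority value.
    urgent = asyncio.create_task(fake_llm("user query", _priority=5))
    print(await urgent)  # served ahead of the still-queued background calls
    await asyncio.gather(*background)
    await fake_llm.shutdown()  # graceful-shutdown hook attached by the decorator


asyncio.run(main())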
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

-__version__ = "1.3.5"
+__version__ = "1.3.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -1 +1 @@
-__api_version__ = "0163"
+__api_version__ = "0164"
File diff suppressed because one or more lines are too long
(Generated, minified webui bundle; content omitted. The only visible difference between the old and new bundle line is the documents chunk it imports: ./feature-documents-DCEXq3Fi.js becomes ./feature-documents-DMvE8vgg.js.)
File diff suppressed because one or more lines are too long
lightrag/api/webui/index.html (generated, 6 lines changed)
@@ -8,7 +8,7 @@
     <link rel="icon" type="image/svg+xml" href="logo.png" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>Lightrag</title>
-    <script type="module" crossorigin src="/webui/assets/index-QYZi8qob.js"></script>
+    <script type="module" crossorigin src="/webui/assets/index-C7DV8sI5.js"></script>
     <link rel="modulepreload" crossorigin href="/webui/assets/react-vendor-DEwriMA6.js">
     <link rel="modulepreload" crossorigin href="/webui/assets/ui-vendor-CeCm8EER.js">
     <link rel="modulepreload" crossorigin href="/webui/assets/graph-vendor-B-X5JegA.js">
@@ -17,9 +17,9 @@
     <link rel="modulepreload" crossorigin href="/webui/assets/mermaid-vendor-c8YIQY7F.js">
     <link rel="modulepreload" crossorigin href="/webui/assets/markdown-vendor-BBaHfVvE.js">
     <link rel="modulepreload" crossorigin href="/webui/assets/feature-retrieval-DaoLh-kj.js">
-    <link rel="modulepreload" crossorigin href="/webui/assets/feature-documents-DCEXq3Fi.js">
+    <link rel="modulepreload" crossorigin href="/webui/assets/feature-documents-DMvE8vgg.js">
     <link rel="stylesheet" crossorigin href="/webui/assets/feature-graph-BipNuM18.css">
-    <link rel="stylesheet" crossorigin href="/webui/assets/index-DsHQCgEh.css">
+    <link rel="stylesheet" crossorigin href="/webui/assets/index-Dd2S8XA6.css">
   </head>
   <body>
     <div id="root"></div>
@@ -161,7 +161,9 @@ class ChromaVectorDBStorage(BaseVectorStorage):
         self, query: str, top_k: int, ids: list[str] | None = None
     ) -> list[dict[str, Any]]:
         try:
-            embedding = await self.embedding_func([query])
+            embedding = await self.embedding_func(
+                [query], _priority=5
+            )  # higher priority for query

             results = self._collection.query(
                 query_embeddings=embedding.tolist()
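The same two-line change repeats for each vector storage backend that follows (Faiss, Milvus, MongoDB, NanoVectorDB, PostgreSQL, Qdrant, TiDB). Mechanically, self.embedding_func is the wrapper produced by priority_limit_async_func_call in LightRAG.__init__ (see that hunk further down), so _priority=5 is consumed by the wrapper's scheduler rather than forwarded to the embedding model. A self-contained sketch of that wiring, with an invented raw_embedding_func standing in for a real model:

# Sketch only; raw_embedding_func and the sizes are invented, the wrapper and _priority come from this diff.
import asyncio

import numpy as np

from lightrag.utils import priority_limit_async_func_call


async def raw_embedding_func(texts: list[str]) -> np.ndarray:
    # Stand-in for a real embedding model; returns an array of shape (len(texts), dim).
    return np.zeros((len(texts), 8), dtype=np.float32)


async def demo():
    # LightRAG.__init__ wraps the configured embedding function the same way.
    wrapped = priority_limit_async_func_call(max_size=8)(raw_embedding_func)
    embedding = await wrapped(["user query"], _priority=5)  # higher priority for query
    print(embedding.shape)  # (1, 8)


asyncio.run(demo())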
@@ -175,7 +175,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Search by a textual query; returns top_k results with their metadata + similarity distance.
         """
-        embedding = await self.embedding_func([query])
+        embedding = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query
         # embedding is shape (1, dim)
         embedding = np.array(embedding, dtype=np.float32)
         faiss.normalize_L2(embedding)  # we do in-place normalization
@@ -197,3 +197,10 @@ class JsonKVStorage(BaseKVStorage):
         except Exception as e:
             logger.error(f"Error dropping {self.namespace}: {e}")
             return {"status": "error", "message": str(e)}
+
+    async def finalize(self):
+        """Finalize storage resources
+        Persist cache data to disk before exiting
+        """
+        if self.namespace.endswith("cache"):
+            await self.index_done_callback()
@@ -104,7 +104,9 @@ class MilvusVectorDBStorage(BaseVectorStorage):
     async def query(
         self, query: str, top_k: int, ids: list[str] | None = None
     ) -> list[dict[str, Any]]:
-        embedding = await self.embedding_func([query])
+        embedding = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query
         results = self._client.search(
             collection_name=self.namespace,
             data=embedding,
@@ -1032,7 +1032,9 @@ class MongoVectorDBStorage(BaseVectorStorage):
     ) -> list[dict[str, Any]]:
         """Queries the vector database using Atlas Vector Search."""
         # Generate the embedding
-        embedding = await self.embedding_func([query])
+        embedding = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query

         # Convert numpy array to a list to ensure compatibility with MongoDB
         query_vector = embedding[0].tolist()
@@ -124,8 +124,10 @@ class NanoVectorDBStorage(BaseVectorStorage):
     async def query(
         self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
-        # Execute embedding outside of lock to avoid long lock times
-        embedding = await self.embedding_func([query])
+        # Execute embedding outside of lock to improve concurrency
+        embedding = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query
         embedding = embedding[0]

         client = await self._get_client()
@@ -644,7 +644,9 @@ class PGVectorStorage(BaseVectorStorage):
     async def query(
         self, query: str, top_k: int, ids: list[str] | None = None
     ) -> list[dict[str, Any]]:
-        embeddings = await self.embedding_func([query])
+        embeddings = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query
         embedding = embeddings[0]
         embedding_string = ",".join(map(str, embedding))
         # Use parameterized document IDs (None means search across all documents)
@@ -124,7 +124,9 @@ class QdrantVectorDBStorage(BaseVectorStorage):
     async def query(
         self, query: str, top_k: int, ids: list[str] | None = None
     ) -> list[dict[str, Any]]:
-        embedding = await self.embedding_func([query])
+        embedding = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query
         results = self._client.search(
             collection_name=self.namespace,
             query_vector=embedding[0],
@@ -65,17 +65,17 @@ class UnifiedLock(Generic[T]):

     async def __aenter__(self) -> "UnifiedLock[T]":
         try:
-            direct_log(
-                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
-                enable_output=self._enable_logging,
-            )
+            # direct_log(
+            #     f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+            #     enable_output=self._enable_logging,
+            # )

             # If in multiprocess mode and async lock exists, acquire it first
             if not self._is_async and self._async_lock is not None:
-                direct_log(
-                    f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
-                    enable_output=self._enable_logging,
-                )
+                # direct_log(
+                #     f"== Lock == Process {self._pid}: Acquiring async lock for '{self._name}'",
+                #     enable_output=self._enable_logging,
+                # )
                 await self._async_lock.acquire()
                 direct_log(
                     f"== Lock == Process {self._pid}: Async lock for '{self._name}' acquired",
@@ -112,10 +112,10 @@ class UnifiedLock(Generic[T]):
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         main_lock_released = False
         try:
-            direct_log(
-                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
-                enable_output=self._enable_logging,
-            )
+            # direct_log(
+            #     f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+            #     enable_output=self._enable_logging,
+            # )

             # Release main lock first
             if self._is_async:
@@ -127,10 +127,10 @@ class UnifiedLock(Generic[T]):

             # Then release async lock if in multiprocess mode
             if not self._is_async and self._async_lock is not None:
-                direct_log(
-                    f"== Lock == Process {self._pid}: Releasing async lock for '{self._name}'",
-                    enable_output=self._enable_logging,
-                )
+                # direct_log(
+                #     f"== Lock == Process {self._pid}: Releasing async lock for '{self._name}'",
+                #     enable_output=self._enable_logging,
+                # )
                 self._async_lock.release()

                 direct_log(
@@ -390,7 +390,9 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         self, query: str, top_k: int, ids: list[str] | None = None
     ) -> list[dict[str, Any]]:
         """Search from tidb vector"""
-        embeddings = await self.embedding_func([query])
+        embeddings = await self.embedding_func(
+            [query], _priority=5
+        )  # higher priority for query
         embedding = embeddings[0]

         embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
@@ -61,7 +61,7 @@ from .utils import (
     compute_mdhash_id,
     convert_response_to_json,
     lazy_external_import,
-    limit_async_func_call,
+    priority_limit_async_func_call,
     get_content_summary,
     clean_text,
     check_storage_env_vars,
@@ -338,9 +338,9 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init Embedding
-        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(  # type: ignore
-            self.embedding_func
-        )
+        self.embedding_func = priority_limit_async_func_call(
+            self.embedding_func_max_async
+        )(self.embedding_func)

         # Initialize all storages
         self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
@@ -426,7 +426,7 @@ class LightRAG:
             # Directly use llm_response_cache, don't create a new object
             hashing_kv = self.llm_response_cache

-        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
+        self.llm_model_func = priority_limit_async_func_call(self.llm_model_max_async)(
             partial(
                 self.llm_model_func,  # type: ignore
                 hashing_kv=hashing_kv,
@@ -1024,7 +1024,8 @@ class LightRAG:
                     }
                 )

-            # Release semaphore before entering the merge stage
+            # Semaphore was released here
+
             if file_extraction_stage_ok:
                 try:
                     # Get chunk_results from entity_relation_task
@@ -1152,9 +1153,6 @@ class LightRAG:
                 try:
                     chunk_results = await extract_entities(
                         chunk,
-                        knowledge_graph_inst=self.chunk_entity_relation_graph,
-                        entity_vdb=self.entities_vdb,
-                        relationships_vdb=self.relationships_vdb,
                         global_config=asdict(self),
                         pipeline_status=pipeline_status,
                         pipeline_status_lock=pipeline_status_lock,
@@ -1445,6 +1443,9 @@ class LightRAG:
         elif param.mode == "bypass":
             # Bypass mode: directly use LLM without knowledge retrieval
             use_llm_func = param.model_func or global_config["llm_model_func"]
+            # Apply higher priority (8) to entity/relation summary tasks
+            use_llm_func = partial(use_llm_func, _priority=8)
+
             param.stream = True if param.stream is None else param.stream
             response = await use_llm_func(
                 query.strip(),
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from functools import partial

 import asyncio
 import traceback
@@ -112,6 +113,9 @@ async def _handle_entity_relation_summary(
     If too long, use LLM to summarize.
     """
     use_llm_func: callable = global_config["llm_model_func"]
+    # Apply higher priority (8) to entity/relation summary tasks
+    use_llm_func = partial(use_llm_func, _priority=8)
+
     tokenizer: Tokenizer = global_config["tokenizer"]
     llm_max_tokens = global_config["llm_model_max_token_size"]
     summary_max_tokens = global_config["summary_to_max_tokens"]
@@ -136,7 +140,7 @@ async def _handle_entity_relation_summary(
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")

-    # Use LLM function with cache
+    # Use LLM function with cache (higher priority for summary generation)
     summary = await use_llm_func_with_cache(
         use_prompt,
         use_llm_func,
@@ -504,8 +508,6 @@ async def merge_nodes_and_edges(
     # Get lock manager from shared storage
     from .kg.shared_storage import get_graph_db_lock

-    graph_db_lock = get_graph_db_lock(enable_logging=False)
-
     # Collect all nodes and edges from all chunks
     all_nodes = defaultdict(list)
     all_edges = defaultdict(list)
@@ -526,6 +528,7 @@ async def merge_nodes_and_edges(

     # Merge nodes and edges
     # Use graph database lock to ensure atomic merges and updates
+    graph_db_lock = get_graph_db_lock(enable_logging=True)
     async with graph_db_lock:
         async with pipeline_status_lock:
             log_message = (
@@ -612,9 +615,6 @@

 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
-    knowledge_graph_inst: BaseGraphStorage,
-    entity_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
     pipeline_status: dict = None,
     pipeline_status_lock=None,
@@ -849,12 +849,14 @@ async def kg_query(
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to query relation LLM function
+        use_model_func = partial(use_model_func, _priority=5)
+
     # Handle cache
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
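The partial(use_model_func, _priority=5) pattern above pre-binds the priority, so downstream query code keeps calling use_model_func(prompt, ...) unchanged. A small sketch of that binding (the llm function below is an invented stand-in; only partial, _priority and the decorator come from this diff):

# Hypothetical llm function; the binding pattern mirrors kg_query above.
import asyncio
from functools import partial

from lightrag.utils import priority_limit_async_func_call


@priority_limit_async_func_call(max_size=4)
async def llm(prompt: str, **kwargs) -> str:
    await asyncio.sleep(0.05)  # stand-in for a real model call
    return f"response to {prompt!r}"


async def demo():
    use_model_func = partial(llm, _priority=5)  # bind the priority once...
    # ...then downstream code calls it like any other LLM function.
    print(await use_model_func("extract keywords from: ..."))


asyncio.run(demo())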
@@ -1050,9 +1052,13 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")

     # 5. Call the LLM for keyword extraction
-    use_model_func = (
-        param.model_func if param.model_func else global_config["llm_model_func"]
-    )
+    if param.model_func:
+        use_model_func = param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to query relation LLM function
+        use_model_func = partial(use_model_func, _priority=5)
+
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
@@ -1115,12 +1121,15 @@ async def mix_kg_vector_query(
     """
     # get tokenizer
     tokenizer: Tokenizer = global_config["tokenizer"]

+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to query relation LLM function
+        use_model_func = partial(use_model_func, _priority=5)
+
     # 1. Cache handling
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -2006,12 +2015,14 @@ async def naive_query(
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to query relation LLM function
+        use_model_func = partial(use_model_func, _priority=5)
+
     # Handle cache
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -2138,15 +2149,16 @@ async def kg_query_with_keywords(
     It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
     Then it uses those to build context and produce a final LLM response.
     """
+    if query_param.model_func:
+        use_model_func = query_param.model_func
+    else:
+        use_model_func = global_config["llm_model_func"]
+        # Apply higher priority (5) to query relation LLM function
+        use_model_func = partial(use_model_func, _priority=5)
+
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = (
-        query_param.model_func
-        if query_param.model_func
-        else global_config["llm_model_func"]
-    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import weakref

 import asyncio
 import html
@@ -267,17 +268,261 @@ def compute_mdhash_id(content: str, prefix: str = "") -> str:
     return prefix + md5(content.encode()).hexdigest()


-def limit_async_func_call(max_size: int):
-    """Add restriction of maximum concurrent async calls using asyncio.Semaphore"""
+# Custom exception class
+class QueueFullError(Exception):
+    """Raised when the queue is full and the wait times out"""
+
+    pass
+
+
+def priority_limit_async_func_call(max_size: int, max_queue_size: int = 1000):
+    """
+    Enhanced priority-limited asynchronous function call decorator
+
+    Args:
+        max_size: Maximum number of concurrent calls
+        max_queue_size: Maximum queue capacity to prevent memory overflow
+    Returns:
+        Decorator function
+    """

     def final_decro(func):
         sem = asyncio.Semaphore(max_size)
+        queue = asyncio.PriorityQueue(maxsize=max_queue_size)
+        tasks = set()
+        initialization_lock = asyncio.Lock()
+        counter = 0
+        shutdown_event = asyncio.Event()
+        initialized = False  # Global initialization flag
+        worker_health_check_task = None
+
+        # Track active future objects for cleanup
+        active_futures = weakref.WeakSet()
+
+        # Worker function to process tasks in the queue
+        async def worker():
+            """Worker that processes tasks in the priority queue"""
+            try:
+                while not shutdown_event.is_set():
+                    try:
+                        # Use timeout to get tasks, allowing periodic checking of shutdown signal
+                        try:
+                            (
+                                priority,
+                                count,
+                                future,
+                                args,
+                                kwargs,
+                            ) = await asyncio.wait_for(queue.get(), timeout=1.0)
+                        except asyncio.TimeoutError:
+                            # Timeout is just to check shutdown signal, continue to next iteration
+                            continue
+
+                        # If future is cancelled, skip execution
+                        if future.cancelled():
+                            queue.task_done()
+                            continue
+
+                        try:
+                            # Execute function
+                            result = await func(*args, **kwargs)
+                            # If future is not done, set the result
+                            if not future.done():
+                                future.set_result(result)
+                        except asyncio.CancelledError:
+                            if not future.done():
+                                future.cancel()
+                            logger.debug("limit_async: Task cancelled during execution")
+                        except Exception as e:
+                            logger.error(
+                                f"limit_async: Error in decorated function: {str(e)}"
+                            )
+                            if not future.done():
+                                future.set_exception(e)
+                        finally:
+                            queue.task_done()
+                    except Exception as e:
+                        # Catch all exceptions in worker loop to prevent worker termination
+                        logger.error(f"limit_async: Critical error in worker: {str(e)}")
+                        await asyncio.sleep(0.1)  # Prevent high CPU usage
+            finally:
+                logger.warning("limit_async: Worker exiting")
+
+        async def health_check():
+            """Periodically check worker health status and recover"""
+            try:
+                while not shutdown_event.is_set():
+                    await asyncio.sleep(5)  # Check every 5 seconds
+
+                    # No longer acquire lock, directly operate on task set
+                    # Use a copy of the task set to avoid concurrent modification
+                    current_tasks = set(tasks)
+                    done_tasks = {t for t in current_tasks if t.done()}
+                    tasks.difference_update(done_tasks)
+
+                    # Calculate active tasks count
+                    active_tasks_count = len(tasks)
+                    workers_needed = max_size - active_tasks_count
+
+                    if workers_needed > 0:
+                        logger.info(
+                            f"limit_async: Creating {workers_needed} new workers"
+                        )
+                        new_tasks = set()
+                        for _ in range(workers_needed):
+                            task = asyncio.create_task(worker())
+                            new_tasks.add(task)
+                            task.add_done_callback(tasks.discard)
+                        # Update task set in one operation
+                        tasks.update(new_tasks)
+            except Exception as e:
+                logger.error(f"limit_async: Error in health check: {str(e)}")
+            finally:
+                logger.warning("limit_async: Health check task exiting")
+
+        async def ensure_workers():
+            """Ensure worker threads and health check system are available
+
+            This function checks if the worker system is already initialized.
+            If not, it performs a one-time initialization of all worker threads
+            and starts the health check system.
+            """
+            nonlocal initialized, worker_health_check_task, tasks
+
+            if initialized:
+                return
+
+            async with initialization_lock:
+                if initialized:
+                    return
+
+                logger.info("limit_async: Initializing worker system")
+
+                # Create initial worker tasks
+                for _ in range(max_size):
+                    task = asyncio.create_task(worker())
+                    tasks.add(task)
+                    task.add_done_callback(tasks.discard)
+
+                # Start health check
+                worker_health_check_task = asyncio.create_task(health_check())
+
+                initialized = True
+                logger.info("limit_async: Worker system initialized")
+
+        async def shutdown():
+            """Gracefully shut down all workers and the queue"""
+            logger.info("limit_async: Shutting down priority queue workers")
+
+            # Set the shutdown event
+            shutdown_event.set()
+
+            # Cancel all active futures
+            for future in list(active_futures):
+                if not future.done():
+                    future.cancel()
+
+            # Wait for the queue to empty
+            try:
+                await asyncio.wait_for(queue.join(), timeout=5.0)
+            except asyncio.TimeoutError:
+                logger.warning(
+                    "limit_async: Timeout waiting for queue to empty during shutdown"
+                )
+
+            # Cancel all worker tasks
+            for task in list(tasks):
+                if not task.done():
+                    task.cancel()
+
+            # Wait for all tasks to complete
+            if tasks:
+                await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Cancel the health check task
+            if worker_health_check_task and not worker_health_check_task.done():
+                worker_health_check_task.cancel()
+                try:
+                    await worker_health_check_task
+                except asyncio.CancelledError:
+                    pass
+
+            logger.info("limit_async: Priority queue workers shutdown complete")
+
         @wraps(func)
-        async def wait_func(*args, **kwargs):
-            async with sem:
-                result = await func(*args, **kwargs)
-                return result
+        async def wait_func(
+            *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs
+        ):
+            """
+            Execute the function with priority-based concurrency control
+            Args:
+                *args: Positional arguments passed to the function
+                _priority: Call priority (lower values have higher priority)
+                _timeout: Maximum time to wait for function completion (in seconds)
+                _queue_timeout: Maximum time to wait for entering the queue (in seconds)
+                **kwargs: Keyword arguments passed to the function
+            Returns:
+                The result of the function call
+            Raises:
+                TimeoutError: If the function call times out
+                QueueFullError: If the queue is full and waiting times out
+                Any exception raised by the decorated function
+            """
+            # Ensure worker system is initialized
+            await ensure_workers()
+
+            # Create a future for the result
+            future = asyncio.Future()
+            active_futures.add(future)
+
+            nonlocal counter
+            async with initialization_lock:
+                current_count = counter
+                counter += 1
+
+            # Try to put the task into the queue, supporting timeout
+            try:
+                if _queue_timeout is not None:
+                    # Use timeout to wait for queue space
+                    try:
+                        await asyncio.wait_for(
+                            queue.put((_priority, current_count, future, args, kwargs)),
+                            timeout=_queue_timeout,
+                        )
+                    except asyncio.TimeoutError:
+                        raise QueueFullError(
+                            f"Queue full, timeout after {_queue_timeout} seconds"
+                        )
+                else:
+                    # No timeout, may wait indefinitely
+                    await queue.put((_priority, current_count, future, args, kwargs))
+            except Exception as e:
+                # Clean up the future
+                if not future.done():
+                    future.set_exception(e)
+                active_futures.discard(future)
+                raise
+
+            try:
+                # Wait for the result, optional timeout
+                if _timeout is not None:
+                    try:
+                        return await asyncio.wait_for(future, _timeout)
+                    except asyncio.TimeoutError:
+                        # Cancel the future
+                        if not future.done():
+                            future.cancel()
+                        raise TimeoutError(
+                            f"limit_async: Task timed out after {_timeout} seconds"
+                        )
+                else:
+                    # Wait for the result without timeout
+                    return await future
+            finally:
+                # Clean up the future reference
+                active_futures.discard(future)
+
+        # Add the shutdown method to the decorated function
+        wait_func.shutdown = shutdown

         return wait_func
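Beyond _priority, the wait_func docstring above documents two opt-in keywords, _timeout and _queue_timeout. A small sketch of how they are meant to fail fast on an overloaded pool (embed, the pool sizes and the timings are invented; QueueFullError, the keywords and the shutdown() hook come from the code above):

# Illustrative only: exercising _queue_timeout and _timeout on a deliberately tiny pool.
import asyncio

from lightrag.utils import QueueFullError, priority_limit_async_func_call


@priority_limit_async_func_call(max_size=1, max_queue_size=2)
async def embed(texts: list[str]) -> list[list[float]]:
    await asyncio.sleep(0.5)  # stand-in for a slow embedding call
    return [[0.0] * 4 for _ in texts]


async def main():
    # Submit more work than the single worker and the two queue slots can absorb.
    filler = [asyncio.create_task(embed([f"doc {i}"])) for i in range(6)]
    await asyncio.sleep(0.05)  # let the workers start and the queue fill up
    try:
        # Give up quickly if no queue slot frees up, and cap the total wait as well.
        await embed(["late query"], _priority=5, _queue_timeout=0.1, _timeout=2.0)
    except (QueueFullError, TimeoutError) as exc:
        print(f"rejected: {exc!r}")
    await asyncio.gather(*filler)
    await embed.shutdown()


asyncio.run(main())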
@@ -726,46 +971,10 @@ async def handle_cache(
     if mode != "default":  # handle cache for all type of query
         if not hashing_kv.global_config.get("enable_llm_cache"):
             return None, None, None, None

-        # TODO: deprecated (PostgreSQL cache not implemented yet)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config",
-            {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        use_llm_check = embedding_cache_config.get("use_llm_check", False)
-
-        quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:  # Use embedding similarity to match cache
-            current_embedding = await hashing_kv.embedding_func([prompt])
-            llm_model_func = hashing_kv.global_config.get("llm_model_func")
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-                mode=mode,
-                use_llm_check=use_llm_check,
-                llm_func=llm_model_func if use_llm_check else None,
-                original_prompt=prompt,
-                cache_type=cache_type,
-            )
-            if best_cached_response is not None:
-                logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
-                return best_cached_response, None, None, None
-            else:
-                # if caching keyword embedding is enabled, return the quantized embedding for saving it later
-                logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
-                return None, quantized, min_val, max_val
-
     else:  # handle cache for entity extraction
         if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
             return None, None, None, None

     # Here is the conditions of code reaching this point:
     # 1. All query mode: enable_llm_cache is True and embedding similarity is not enabled
     # 2. Entity extract: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:
@@ -1340,7 +1549,7 @@ async def use_llm_func_with_cache(

     Args:
         input_text: Input text to send to LLM
-        use_llm_func: LLM function to call
+        use_llm_func: LLM function with higher priority
         llm_response_cache: Cache storage instance
         max_tokens: Maximum tokens for generation
         history_messages: History messages list
@@ -85,7 +85,7 @@ export default function PipelineStatusDialog({
     <Dialog open={open} onOpenChange={onOpenChange}>
       <DialogContent
         className={cn(
-          'sm:max-w-[600px] transition-all duration-200 fixed',
+          'sm:max-w-[800px] transition-all duration-200 fixed',
           position === 'left' && '!left-[25%] !translate-x-[-50%] !mx-4',
           position === 'center' && '!left-1/2 !-translate-x-1/2',
           position === 'right' && '!left-[75%] !translate-x-[-50%] !mx-4'
@@ -166,7 +166,7 @@ export default function PipelineStatusDialog({
           {/* Latest Message */}
           <div className="space-y-2">
             <div className="text-sm font-medium">{t('documentPanel.pipelineStatus.latestMessage')}:</div>
-            <div className="font-mono text-xs rounded-md bg-zinc-800 text-zinc-100 p-3">
+            <div className="font-mono text-xs rounded-md bg-zinc-800 text-zinc-100 p-3 whitespace-pre-wrap break-words">
              {status?.latest_message || '-'}
            </div>
          </div>
@@ -181,7 +181,7 @@ export default function PipelineStatusDialog({
           >
             {status?.history_messages?.length ? (
               status.history_messages.map((msg, idx) => (
-                <div key={idx}>{msg}</div>
+                <div key={idx} className="whitespace-pre-wrap break-words">{msg}</div>
               ))
             ) : '-'}
           </div>