Fix concurrency problem in the extract_entities function.

- Abandon the approach of temporarily replacing the global llm_model_func configuration
- Introduce a custom_llm function that carries new_config into handle_cache while extracting entities (a minimal sketch of the pattern follows the commit metadata below)
- Update handle_cache to accept the custom_llm override via a new optional llm parameter
Author: yangdx
Date: 2025-01-30 02:45:33 +08:00
Parent: 06647438b2
Commit: cc50ade14e
2 changed files with 24 additions and 8 deletions
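
The closure approach described in the commit message avoids mutating llm_response_cache.global_config, which the old code swapped and then restored around each call and which could race when several extractions ran concurrently. Below is a minimal, self-contained Python sketch of that wrapper pattern; use_llm_func and new_config mirror names in the diff, while make_custom_llm, the stub body of use_llm_func, and main are illustrative additions, not code from this commit.

import asyncio


async def use_llm_func(prompt: str, **kwargs) -> str:
    # Stand-in for the real model call; the commit forwards the merged
    # config to the configured LLM instead.
    return f"echo: {prompt}"


def make_custom_llm(new_config: dict):
    # Bind the per-extraction config into a closure so handle_cache can use it
    # without touching shared global state (the old approach mutated
    # llm_response_cache.global_config, which raced under concurrency).
    async def custom_llm(prompt, system_prompt=None, history_messages=None, **kwargs) -> str:
        merged_config = {**kwargs, **new_config}  # new_config wins on key clashes
        return await use_llm_func(
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages or [],
            **merged_config,
        )

    return custom_llm


async def main():
    custom_llm = make_custom_llm({"enable_llm_cache": True, "embedding_cache_config": None})
    print(await custom_llm("Extract entities from: Alice met Bob in Paris."))


asyncio.run(main())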


@@ -352,7 +352,7 @@ async def extract_entities(
         input_text: str, history_messages: list[dict[str, str]] = None
     ) -> str:
         if enable_llm_cache_for_entity_extract and llm_response_cache:
-            need_to_restore = False
+            custom_llm = None
             if (
                 global_config["embedding_cache_config"]
                 and global_config["embedding_cache_config"]["enabled"]
@@ -360,8 +360,21 @@ async def extract_entities(
                 new_config = global_config.copy()
                 new_config["embedding_cache_config"] = None
                 new_config["enable_llm_cache"] = True
-                llm_response_cache.global_config = new_config
-                need_to_restore = True
+                # create an LLM function bound to new_config for handle_cache
+                async def custom_llm(
+                    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+                ) -> str:
+                    # merge new_config with the other kwargs so the other parameters are not overwritten
+                    merged_config = {**kwargs, **new_config}
+                    return await use_llm_func(
+                        prompt,
+                        system_prompt=system_prompt,
+                        history_messages=history_messages,
+                        keyword_extraction=keyword_extraction,
+                        **merged_config,
+                    )
+
             if history_messages:
                 history = json.dumps(history_messages, ensure_ascii=False)
                 _prompt = history + "\n" + input_text
@@ -370,10 +383,13 @@ async def extract_entities(
             arg_hash = compute_args_hash(_prompt)
             cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
+                llm_response_cache,
+                arg_hash,
+                _prompt,
+                "default",
+                cache_type="default",
+                llm=custom_llm
             )
-            if need_to_restore:
-                llm_response_cache.global_config = global_config
             if cached_return:
                 logger.debug(f"Found cache for {arg_hash}")
                 statistic_data["llm_cache"] += 1


@@ -491,7 +491,7 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)


-async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
+async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None, llm=None):
     """Generic cache handling function"""
     if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
         return None, None, None, None
@@ -528,7 +528,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type
             similarity_threshold=embedding_cache_config["similarity_threshold"],
             mode=mode,
             use_llm_check=use_llm_check,
-            llm_func=llm_model_func if use_llm_check else None,
+            llm_func=llm if (use_llm_check and llm is not None) else (llm_model_func if use_llm_check else None),
             original_prompt=prompt if use_llm_check else None,
             cache_type=cache_type,
         )