feat(cache): add LLM similarity check and improve the caching mechanism
- Add a use_llm_check parameter to the embedding cache configuration
- Implement LLM similarity check logic as a second-pass validation of cache hits
- Streamline cache handling in the naive query mode
- Adjust the cache data structure, removing the unnecessary model field
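
The LLM similarity check itself lives in the embedding-cache helpers and is not part of the excerpt below. As a rough sketch of the idea only (the function name llm_verify_cache_hit and the prompt wording are illustrative, not the project's actual code): when use_llm_check is enabled, a hit found by embedding similarity is handed back to the LLM together with both questions, and the cached answer is returned only if the LLM judges the questions equivalent.

# Illustrative sketch only -- not the implementation behind use_llm_check.
async def llm_verify_cache_hit(llm_func, new_query: str, cached_query: str) -> bool:
    """Ask the LLM whether two questions ask for the same information."""
    prompt = (
        "Question 1: " + new_query + "\n"
        "Question 2: " + cached_query + "\n"
        "Do these two questions ask for the same information? Answer yes or no."
    )
    answer = await llm_func(prompt)
    return answer.strip().lower().startswith("yes")
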
@@ -17,6 +17,10 @@ from .utils import (
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
+    compute_args_hash,
+    handle_cache,
+    save_to_cache,
+    CacheData,
 )
 from .base import (
     BaseGraphStorage,
@@ -452,8 +456,17 @@ async def kg_query(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ) -> str:
-    context = None
+    # Handle cache
+    use_model_func = global_config["llm_model_func"]
+    args_hash = compute_args_hash(query_param.mode, query)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     example_number = global_config["addon_params"].get("example_number", None)
     if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
         examples = "\n".join(
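
compute_args_hash is imported from .utils and its body is not shown in this diff. A minimal stand-in consistent with how it is called above (hashing the query mode together with the query text to form the cache key) might look like this; the real helper may differ in detail:

from hashlib import md5

# Stand-in for .utils.compute_args_hash, shown only to make the call above concrete.
def compute_args_hash(*args) -> str:
    # Hash the positional arguments (here: mode and query) into a stable cache key.
    return md5(str(args).encode()).hexdigest()
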
@@ -471,12 +484,9 @@ async def kg_query(
         return PROMPTS["fail_response"]

     # LLM generate keywords
-    use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(
-        kw_prompt, keyword_extraction=True, mode=query_param.mode
-    )
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -537,7 +547,6 @@ async def kg_query(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
-        mode=query_param.mode,
     )
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
@@ -550,6 +559,20 @@ async def kg_query(
             .strip()
         )

+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            prompt=query,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=query_param.mode,
+        ),
+    )
+
     return response

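
CacheData is likewise defined in .utils and only its construction is visible here. A dataclass carrying exactly the fields passed at both call sites (and no model field, per the commit message) would look roughly like the sketch below; field defaults and typing are assumptions:

from dataclasses import dataclass
from typing import Any, Optional

# Sketch of the cache record implied by the save_to_cache(...) calls in this diff;
# the real CacheData lives in .utils and may differ in defaults and typing.
@dataclass
class CacheData:
    args_hash: str                     # key produced by compute_args_hash(mode, query)
    content: str                       # the LLM response being cached
    prompt: str                        # the original query text
    quantized: Any = None              # quantized query embedding, used for similarity lookup
    min_val: Optional[float] = None    # de-quantization bounds for the embedding
    max_val: Optional[float] = None
    mode: str = "default"              # query mode (e.g. naive / local / global / hybrid)
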
@@ -1013,8 +1036,17 @@ async def naive_query(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
     global_config: dict,
+    hashing_kv: BaseKVStorage = None,
 ):
+    # Handle cache
     use_model_func = global_config["llm_model_func"]
+    args_hash = compute_args_hash(query_param.mode, query)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
@@ -1039,7 +1071,6 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
-        mode=query_param.mode,
     )

     if len(response) > len(sys_prompt):
@@ -1054,4 +1085,18 @@ async def naive_query(
             .strip()
         )

+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            prompt=query,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=query_param.mode,
+        ),
+    )
+
     return response
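
Both query paths now share the same round trip: handle_cache(hashing_kv, args_hash, query, query_param.mode) is awaited before the LLM call and save_to_cache(hashing_kv, CacheData(...)) afterwards. The real helpers in .utils also cover the embedding-similarity lookup, quantization, and the optional LLM check; the sketch below shows only the exact-match contract implied by these call sites, assuming nothing beyond the async get_by_id/upsert operations of the key-value storage interface:

# Exact-match-only sketch of the cache contract used above; the real
# handle_cache/save_to_cache in .utils also handle embedding similarity,
# quantization of the query embedding, and the optional LLM check.
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    if hashing_kv is None:
        return None, None, None, None
    mode_cache = await hashing_kv.get_by_id(mode) or {}
    hit = mode_cache.get(args_hash)
    if hit is not None:
        return hit["return"], None, None, None
    return None, None, None, None


async def save_to_cache(hashing_kv, cache_data):
    if hashing_kv is None or cache_data.content is None:
        return
    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "original_prompt": cache_data.prompt,
    }
    await hashing_kv.upsert({cache_data.mode: mode_cache})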