Refactor cache handling logic

- Extract the shared cache handling logic into new functions handle_cache and save_to_cache
- Unify the cached data structure with a CacheData class
- Streamline the processing flow for both the embedding cache and the regular cache
- Add a mode parameter to support caching strategies for different query modes
- Refactor the get_best_cached_response function to improve cache lookup efficiency
yuanxiaobin
2024-12-06 14:29:16 +08:00
parent 7c4bbe2474
commit 584258078f
3 changed files with 277 additions and 309 deletions
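
For orientation, below is a minimal, self-contained sketch of the call pattern this commit introduces: a completion wrapper checks the mode-partitioned cache via handle_cache before calling the model, then writes the result back with save_to_cache. The InMemoryKV class, the demo hash value, and the simplified handle_cache/save_to_cache bodies are illustrative stand-ins (the embedding-similarity path and the real BaseKVStorage are omitted); only the CacheData fields and the mode-keyed storage layout mirror the diff.

import asyncio
from dataclasses import dataclass
from typing import Optional


class InMemoryKV:
    """Illustrative stand-in for BaseKVStorage: one record per query mode."""

    def __init__(self):
        self._data = {}

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, records: dict):
        self._data.update(records)


@dataclass
class CacheData:  # same fields as the dataclass added in this commit
    args_hash: str
    content: str
    model: str
    prompt: str
    quantized: Optional[object] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"


async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    # Regular-cache path only; the embedding-similarity path is omitted here.
    if hashing_kv is None:
        return None, None, None, None
    mode_cache = await hashing_kv.get_by_id(mode) or {}
    if args_hash in mode_cache:
        return mode_cache[args_hash]["return"], None, None, None
    return None, None, None, None


async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None:
        return
    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "model": cache_data.model,
        "original_prompt": cache_data.prompt,
    }
    await hashing_kv.upsert({cache_data.mode: mode_cache})


async def main():
    kv = InMemoryKV()
    args_hash = "demo-hash"  # stands in for compute_args_hash(model, messages)
    prompt = "What is retrieval-augmented generation?"

    cached, *_ = await handle_cache(kv, args_hash, prompt, mode="local")
    assert cached is None  # first call in mode "local": cache miss

    await save_to_cache(
        kv,
        CacheData(args_hash=args_hash, content="answer", model="demo",
                  prompt=prompt, mode="local"),
    )

    cached, *_ = await handle_cache(kv, args_hash, prompt, mode="local")
    print(cached)  # "answer" -- cache hit, scoped to the "local" query mode


asyncio.run(main())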

View File

@@ -4,7 +4,8 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Optional
+from dataclasses import dataclass

 import aioboto3
 import aiohttp
@@ -59,39 +60,21 @@ async def openai_complete_if_cache(
     openai_async_client = (
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
@@ -105,24 +88,21 @@ async def openai_complete_if_cache(
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return content
@@ -155,6 +135,8 @@ async def azure_openai_complete_if_cache(
     )
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    mode = kwargs.pop("mode", "default")

     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -162,56 +144,35 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
+    content = response.choices[0].message.content

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response.choices[0].message.content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return response.choices[0].message.content
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+
+    return content


 class BedrockError(Exception):
@@ -253,6 +214,15 @@ async def bedrock_complete_if_cache(
     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})

+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     # Initialize Converse API arguments
     args = {"modelId": model, "messages": messages}
@@ -275,33 +245,14 @@ async def bedrock_complete_if_cache(
                 kwargs.pop(param)
             )

-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     # Call model via Converse API
     session = aioboto3.Session()
@@ -311,27 +262,19 @@ async def bedrock_complete_if_cache(
     except Exception as e:
         raise BedrockError(e)

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response["output"]["message"]["content"][0]["text"],
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_max": max_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=response["output"]["message"]["content"][0]["text"],
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return response["output"]["message"]["content"][0]["text"]
@@ -372,32 +315,14 @@ async def hf_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     input_prompt = ""
     try:
@@ -442,24 +367,22 @@ async def hf_model_if_cache(
     response_text = hf_tokenizer.decode(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response_text,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response_text,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return response_text
@@ -489,55 +412,34 @@ async def ollama_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)

     result = response["message"]["content"]

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return result
@@ -649,32 +551,14 @@ async def lmdeploy_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -692,24 +576,21 @@ async def lmdeploy_model_if_cache(
     ):
         response += res.response

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return response
@@ -1139,6 +1020,75 @@ class MultiModel:
         return await next_model.gen_func(**args)


+async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+    """Generic cache handling function"""
+    if hashing_kv is None:
+        return None, None, None, None
+
+    # Get embedding cache configuration
+    embedding_cache_config = hashing_kv.global_config.get(
+        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+    )
+    is_embedding_cache_enabled = embedding_cache_config["enabled"]
+
+    quantized = min_val = max_val = None
+    if is_embedding_cache_enabled:
+        # Use embedding cache
+        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        current_embedding = await embedding_model_func([prompt])
+        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+        best_cached_response = await get_best_cached_response(
+            hashing_kv,
+            current_embedding[0],
+            similarity_threshold=embedding_cache_config["similarity_threshold"],
+            mode=mode,
+        )
+        if best_cached_response is not None:
+            return best_cached_response, None, None, None
+    else:
+        # Use regular cache
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+
+    return None, quantized, min_val, max_val
+
+
+@dataclass
+class CacheData:
+    args_hash: str
+    content: str
+    model: str
+    prompt: str
+    quantized: Optional[np.ndarray] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    mode: str = "default"
+
+
+async def save_to_cache(hashing_kv, cache_data: CacheData):
+    if hashing_kv is None:
+        return
+
+    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    mode_cache[cache_data.args_hash] = {
+        "return": cache_data.content,
+        "model": cache_data.model,
+        "embedding": cache_data.quantized.tobytes().hex()
+        if cache_data.quantized is not None
+        else None,
+        "embedding_shape": cache_data.quantized.shape
+        if cache_data.quantized is not None
+        else None,
+        "embedding_min": cache_data.min_val,
+        "embedding_max": cache_data.max_val,
+        "original_prompt": cache_data.prompt,
+    }
+
+    await hashing_kv.upsert({cache_data.mode: mode_cache})
+
+
 if __name__ == "__main__":
     import asyncio

View File

@@ -474,7 +474,9 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt, keyword_extraction=True)
+    result = await use_model_func(
+        kw_prompt, keyword_extraction=True, mode=query_param.mode
+    )
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -534,6 +536,7 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )
     if len(response) > len(sys_prompt):
         response = (
@@ -1035,6 +1038,7 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
    )
    if len(response) > len(sys_prompt):

View File

@@ -310,19 +310,24 @@ def process_combine_contexts(hl, ll):

 async def get_best_cached_response(
-    hashing_kv, current_embedding, similarity_threshold=0.95
-):
-    """Get the cached response with the highest similarity"""
-    try:
-        # Get all keys
-        all_keys = await hashing_kv.all_keys()
-        max_similarity = 0
-        best_cached_response = None
-
-        # Get cached data one by one
-        for key in all_keys:
-            cache_data = await hashing_kv.get_by_id(key)
-            if cache_data is None or "embedding" not in cache_data:
+    hashing_kv,
+    current_embedding,
+    similarity_threshold=0.95,
+    mode="default",
+) -> Union[str, None]:
+    # Get mode-specific cache
+    mode_cache = await hashing_kv.get_by_id(mode)
+    if not mode_cache:
+        return None
+
+    best_similarity = -1
+    best_response = None
+    best_prompt = None
+    best_cache_id = None
+
+    # Only iterate through cache entries for this mode
+    for cache_id, cache_data in mode_cache.items():
+        if cache_data["embedding"] is None:
             continue

         # Convert cached embedding list to ndarray
@@ -336,16 +341,25 @@ async def get_best_cached_response(
         )

         similarity = cosine_similarity(current_embedding, cached_embedding)
-        if similarity > max_similarity:
-            max_similarity = similarity
-            best_cached_response = cache_data["return"]
+        if similarity > best_similarity:
+            best_similarity = similarity
+            best_response = cache_data["return"]
+            best_prompt = cache_data["original_prompt"]
+            best_cache_id = cache_id

-        if max_similarity > similarity_threshold:
-            return best_cached_response
-        return None
-
-    except Exception as e:
-        logger.warning(f"Error in get_best_cached_response: {e}")
+    if best_similarity > similarity_threshold:
+        prompt_display = (
+            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
+        )
+        log_data = {
+            "event": "cache_hit",
+            "mode": mode,
+            "similarity": round(best_similarity, 4),
+            "cache_id": best_cache_id,
+            "original_prompt": prompt_display,
+        }
+        logger.info(json.dumps(log_data))
+        return best_response
     return None