Merge pull request #412 from magicyuan876/main

Refactor cache logic
zrguo
2024-12-06 18:10:48 +08:00
committed by GitHub
3 changed files with 295 additions and 314 deletions
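
Summary of the refactor: the cache lookup and cache write logic that was previously duplicated inside each model wrapper (openai_complete_if_cache, azure_openai_complete_if_cache, bedrock_complete_if_cache, hf_model_if_cache, ollama_model_if_cache, lmdeploy_model_if_cache) is pulled into three shared pieces defined at the end of the first file: handle_cache for lookups, the CacheData dataclass, and save_to_cache for writes. Cache entries are now grouped under a query "mode" key, which the second file threads from query_param.mode through the llm_model_func kwargs. A minimal sketch of the pattern every wrapper now follows, assembled from names in this diff (call_model is a hypothetical stand-in for the provider-specific request):

async def example_complete_if_cache(model, prompt, system_prompt=None, history_messages=[], **kwargs):
    # Illustrative wrapper, not part of this commit: shows the shared cache flow only.
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    mode = kwargs.pop("mode", "default")
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    # 1. Look up the cache (exact args hash, or embedding similarity when enabled).
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    # 2. Call the underlying model (hypothetical helper standing in for the provider SDK).
    content = await call_model(model, messages, **kwargs)
    # 3. Persist the result under {mode: {args_hash: ...}}.
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return content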


@@ -4,8 +4,8 @@ import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union
from typing import List, Dict, Callable, Any, Union, Optional
from dataclasses import dataclass
import aioboto3
import aiohttp
import numpy as np
@@ -66,39 +66,22 @@ async def openai_complete_if_cache(
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
)
if best_cached_response is not None:
return best_cached_response
else:
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
@@ -112,24 +95,21 @@ async def openai_complete_if_cache(
if r"\u" in content:
content = content.encode("utf-8").decode("unicode_escape")
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": content,
"model": model,
"embedding": quantized.tobytes().hex()
if is_embedding_cache_enabled
else None,
"embedding_shape": quantized.shape
if is_embedding_cache_enabled
else None,
"embedding_min": min_val if is_embedding_cache_enabled else None,
"embedding_max": max_val if is_embedding_cache_enabled else None,
"original_prompt": prompt,
}
}
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
@@ -162,6 +142,8 @@ async def azure_openai_complete_if_cache(
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
mode = kwargs.pop("mode", "default")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -169,56 +151,35 @@ async def azure_openai_complete_if_cache(
if prompt is not None:
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
)
if best_cached_response is not None:
return best_cached_response
else:
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Handle cache
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
content = response.choices[0].message.content
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response.choices[0].message.content,
"model": model,
"embedding": quantized.tobytes().hex()
if is_embedding_cache_enabled
else None,
"embedding_shape": quantized.shape
if is_embedding_cache_enabled
else None,
"embedding_min": min_val if is_embedding_cache_enabled else None,
"embedding_max": max_val if is_embedding_cache_enabled else None,
"original_prompt": prompt,
}
}
)
return response.choices[0].message.content
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
class BedrockError(Exception):
@@ -259,6 +220,15 @@ async def bedrock_complete_if_cache(
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
@@ -281,34 +251,15 @@ async def bedrock_complete_if_cache(
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
)
if best_cached_response is not None:
return best_cached_response
else:
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
# Call model via Converse API
session = aioboto3.Session()
@@ -318,30 +269,22 @@ async def bedrock_complete_if_cache(
except Exception as e:
raise BedrockError(e)
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response["output"]["message"]["content"][0]["text"],
"model": model,
"embedding": quantized.tobytes().hex()
if is_embedding_cache_enabled
else None,
"embedding_shape": quantized.shape
if is_embedding_cache_enabled
else None,
"embedding_min": min_val
if is_embedding_cache_enabled
else None,
"embedding_max": max_val
if is_embedding_cache_enabled
else None,
"original_prompt": prompt,
}
}
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response["output"]["message"]["content"][0]["text"],
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response["output"]["message"]["content"][0]["text"]
return response["output"]["message"]["content"][0]["text"]
@lru_cache(maxsize=1)
@@ -379,32 +322,14 @@ async def hf_model_if_cache(
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
)
if best_cached_response is not None:
return best_cached_response
else:
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
input_prompt = ""
try:
@@ -449,24 +374,22 @@ async def hf_model_if_cache(
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response_text,
"model": model,
"embedding": quantized.tobytes().hex()
if is_embedding_cache_enabled
else None,
"embedding_shape": quantized.shape
if is_embedding_cache_enabled
else None,
"embedding_min": min_val if is_embedding_cache_enabled else None,
"embedding_max": max_val if is_embedding_cache_enabled else None,
"original_prompt": prompt,
}
}
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response_text,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response_text
@@ -497,32 +420,14 @@ async def ollama_model_if_cache(
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
)
if best_cached_response is not None:
return best_cached_response
else:
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
if stream:
@@ -535,29 +440,39 @@ async def ollama_model_if_cache(
return inner()
else:
result = response["message"]["content"]
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": result,
"model": model,
"embedding": quantized.tobytes().hex()
if is_embedding_cache_enabled
else None,
"embedding_shape": quantized.shape
if is_embedding_cache_enabled
else None,
"embedding_min": min_val
if is_embedding_cache_enabled
else None,
"embedding_max": max_val
if is_embedding_cache_enabled
else None,
"original_prompt": prompt,
}
}
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return result
result = response["message"]["content"]
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return result
@lru_cache(maxsize=1)
@@ -668,32 +583,14 @@ async def lmdeploy_model_if_cache(
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
)
if best_cached_response is not None:
return best_cached_response
else:
# Use regular cache
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
@@ -711,24 +608,21 @@ async def lmdeploy_model_if_cache(
):
response += res.response
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response,
"model": model,
"embedding": quantized.tobytes().hex()
if is_embedding_cache_enabled
else None,
"embedding_shape": quantized.shape
if is_embedding_cache_enabled
else None,
"embedding_min": min_val if is_embedding_cache_enabled else None,
"embedding_max": max_val if is_embedding_cache_enabled else None,
"original_prompt": prompt,
}
}
)
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return response
@@ -1158,6 +1052,75 @@ class MultiModel:
return await next_model.gen_func(**args)
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
"""Generic cache handling function"""
if hashing_kv is None:
return None, None, None, None
# Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
is_embedding_cache_enabled = embedding_cache_config["enabled"]
quantized = min_val = max_val = None
if is_embedding_cache_enabled:
# Use embedding cache
embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
current_embedding = await embedding_model_func([prompt])
quantized, min_val, max_val = quantize_embedding(current_embedding[0])
best_cached_response = await get_best_cached_response(
hashing_kv,
current_embedding[0],
similarity_threshold=embedding_cache_config["similarity_threshold"],
mode=mode,
)
if best_cached_response is not None:
return best_cached_response, None, None, None
else:
# Use regular cache
mode_cache = await hashing_kv.get_by_id(mode) or {}
if args_hash in mode_cache:
return mode_cache[args_hash]["return"], None, None, None
return None, quantized, min_val, max_val
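# Return contract of handle_cache (as implemented above):
#   (cached_response, quantized, min_val, max_val)
# - hit  -> (response, None, None, None): the caller returns the cached text directly.
# - miss with embedding cache enabled -> (None, quantized, min_val, max_val): the
#   quantized prompt embedding is carried forward so save_to_cache can store it.
# - miss with embedding cache disabled, or hashing_kv is None -> (None, None, None, None).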
@dataclass
class CacheData:
args_hash: str
content: str
model: str
prompt: str
quantized: Optional[np.ndarray] = None
min_val: Optional[float] = None
max_val: Optional[float] = None
mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
if hashing_kv is None:
return
mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
mode_cache[cache_data.args_hash] = {
"return": cache_data.content,
"model": cache_data.model,
"embedding": cache_data.quantized.tobytes().hex()
if cache_data.quantized is not None
else None,
"embedding_shape": cache_data.quantized.shape
if cache_data.quantized is not None
else None,
"embedding_min": cache_data.min_val,
"embedding_max": cache_data.max_val,
"original_prompt": cache_data.prompt,
}
await hashing_kv.upsert({cache_data.mode: mode_cache})
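# Illustrative shape of the stored record after save_to_cache (one KV entry per mode,
# each mapping args_hash to the payload built above):
# {
#     "default": {
#         "<args_hash>": {
#             "return": "<model output>",
#             "model": "<model name>",
#             "embedding": "<hex-encoded uint8 bytes>",   # or None when no embedding cache
#             "embedding_shape": [<dim>],                 # or None
#             "embedding_min": <float>,                   # or None
#             "embedding_max": <float>,                   # or None
#             "original_prompt": "<prompt>",
#         }
#     }
# }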
if __name__ == "__main__":
import asyncio


@@ -474,7 +474,9 @@ async def kg_query(
use_model_func = global_config["llm_model_func"]
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func(kw_prompt, keyword_extraction=True)
result = await use_model_func(
kw_prompt, keyword_extraction=True, mode=query_param.mode
)
logger.info("kw_prompt result:")
print(result)
try:
@@ -535,6 +537,7 @@ async def kg_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
mode=query_param.mode,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
@@ -1036,6 +1039,7 @@ async def naive_query(
response = await use_model_func(
query,
system_prompt=sys_prompt,
mode=query_param.mode,
)
if len(response) > len(sys_prompt):


@@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll):
async def get_best_cached_response(
hashing_kv, current_embedding, similarity_threshold=0.95
):
"""Get the cached response with the highest similarity"""
try:
# Get all keys
all_keys = await hashing_kv.all_keys()
max_similarity = 0
best_cached_response = None
# Get cached data one by one
for key in all_keys:
cache_data = await hashing_kv.get_by_id(key)
if cache_data is None or "embedding" not in cache_data:
continue
# Convert cached embedding list to ndarray
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
cached_embedding = dequantize_embedding(
cached_quantized,
cache_data["embedding_min"],
cache_data["embedding_max"],
)
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > max_similarity:
max_similarity = similarity
best_cached_response = cache_data["return"]
if max_similarity > similarity_threshold:
return best_cached_response
hashing_kv,
current_embedding,
similarity_threshold=0.95,
mode="default",
) -> Union[str, None]:
# Get mode-specific cache
mode_cache = await hashing_kv.get_by_id(mode)
if not mode_cache:
return None
except Exception as e:
logger.warning(f"Error in get_best_cached_response: {e}")
return None
best_similarity = -1
best_response = None
best_prompt = None
best_cache_id = None
# Only iterate through cache entries for this mode
for cache_id, cache_data in mode_cache.items():
if cache_data["embedding"] is None:
continue
# Convert cached embedding list to ndarray
cached_quantized = np.frombuffer(
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
).reshape(cache_data["embedding_shape"])
cached_embedding = dequantize_embedding(
cached_quantized,
cache_data["embedding_min"],
cache_data["embedding_max"],
)
similarity = cosine_similarity(current_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_response = cache_data["return"]
best_prompt = cache_data["original_prompt"]
best_cache_id = cache_id
if best_similarity > similarity_threshold:
prompt_display = (
best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
)
log_data = {
"event": "cache_hit",
"mode": mode,
"similarity": round(best_similarity, 4),
"cache_id": best_cache_id,
"original_prompt": prompt_display,
}
logger.info(json.dumps(log_data, ensure_ascii=False))
return best_response
return None
def cosine_similarity(v1, v2):
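
The helpers above store embeddings as hex-encoded uint8 arrays produced by quantize_embedding and rebuilt by dequantize_embedding; neither helper appears in this diff. A minimal sketch of the assumed min-max round trip, for illustration only (the real implementations live elsewhere in the codebase and may differ):

import numpy as np

def quantize_embedding(embedding: np.ndarray, bits: int = 8):
    # Map floats onto uint8 codes; returns the (quantized, min_val, max_val) triple the cache stores.
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (max_val - min_val) or 1.0
    quantized = np.round((embedding - min_val) / scale * (2**bits - 1)).astype(np.uint8)
    return quantized, min_val, max_val

def dequantize_embedding(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8) -> np.ndarray:
    # Inverse mapping: recover approximate float values before computing cosine similarity.
    scale = (max_val - min_val) / (2**bits - 1)
    return quantized.astype(np.float32) * scale + min_val

# Storage round trip as used by the cache:
#   hex_str = quantized.tobytes().hex()
#   restored = np.frombuffer(bytes.fromhex(hex_str), dtype=np.uint8).reshape(quantized.shape)
#   embedding ~= dequantize_embedding(restored, min_val, max_val)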