Merge pull request #412 from magicyuan876/main

Refactor cache logic
zrguo, 2024-12-06 18:10:48 +08:00, committed by GitHub
3 changed files with 295 additions and 314 deletions

View File

@@ -4,8 +4,8 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Union
+from typing import List, Dict, Callable, Any, Union, Optional
+from dataclasses import dataclass
 import aioboto3
 import aiohttp
 import numpy as np
@@ -66,39 +66,22 @@ async def openai_complete_if_cache(
     openai_async_client = (
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
@@ -112,24 +95,21 @@ async def openai_complete_if_cache(
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
     return content
@@ -162,6 +142,8 @@ async def azure_openai_complete_if_cache(
     )
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    mode = kwargs.pop("mode", "default")
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -169,56 +151,35 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response.choices[0].message.content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return response.choices[0].message.content
+    content = response.choices[0].message.content
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+    return content
 class BedrockError(Exception):
@@ -259,6 +220,15 @@ async def bedrock_complete_if_cache(
     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     # Initialize Converse API arguments
     args = {"modelId": model, "messages": messages}
@@ -281,34 +251,15 @@ async def bedrock_complete_if_cache(
             args["inferenceConfig"][inference_params_map.get(param, param)] = (
                 kwargs.pop(param)
             )
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     # Call model via Converse API
     session = aioboto3.Session()
@@ -318,30 +269,22 @@ async def bedrock_complete_if_cache(
         except Exception as e:
             raise BedrockError(e)
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response["output"]["message"]["content"][0]["text"],
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_max": max_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response["output"]["message"]["content"][0]["text"],
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
     return response["output"]["message"]["content"][0]["text"]
 @lru_cache(maxsize=1)
@@ -379,32 +322,14 @@ async def hf_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     input_prompt = ""
     try:
@@ -449,24 +374,22 @@ async def hf_model_if_cache(
     response_text = hf_tokenizer.decode(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response_text,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response_text,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
     return response_text
@@ -497,32 +420,14 @@ async def ollama_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
@@ -535,29 +440,39 @@ async def ollama_model_if_cache(
         return inner()
     else:
         result = response["message"]["content"]
-        if hashing_kv is not None:
-            await hashing_kv.upsert(
-                {
-                    args_hash: {
-                        "return": result,
-                        "model": model,
-                        "embedding": quantized.tobytes().hex()
-                        if is_embedding_cache_enabled
-                        else None,
-                        "embedding_shape": quantized.shape
-                        if is_embedding_cache_enabled
-                        else None,
-                        "embedding_min": min_val
-                        if is_embedding_cache_enabled
-                        else None,
-                        "embedding_max": max_val
-                        if is_embedding_cache_enabled
-                        else None,
-                        "original_prompt": prompt,
-                    }
-                }
-            )
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=result,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
         return result
+    result = response["message"]["content"]
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+    return result
 @lru_cache(maxsize=1)
@@ -668,32 +583,14 @@ async def lmdeploy_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -711,24 +608,21 @@ async def lmdeploy_model_if_cache(
     ):
         response += res.response
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
     return response
@@ -1158,6 +1052,75 @@ class MultiModel:
         return await next_model.gen_func(**args)
+async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+    """Generic cache handling function"""
+    if hashing_kv is None:
+        return None, None, None, None
+
+    # Get embedding cache configuration
+    embedding_cache_config = hashing_kv.global_config.get(
+        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+    )
+    is_embedding_cache_enabled = embedding_cache_config["enabled"]
+
+    quantized = min_val = max_val = None
+    if is_embedding_cache_enabled:
+        # Use embedding cache
+        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        current_embedding = await embedding_model_func([prompt])
+        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+        best_cached_response = await get_best_cached_response(
+            hashing_kv,
+            current_embedding[0],
+            similarity_threshold=embedding_cache_config["similarity_threshold"],
+            mode=mode,
+        )
+        if best_cached_response is not None:
+            return best_cached_response, None, None, None
+    else:
+        # Use regular cache
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+
+    return None, quantized, min_val, max_val
+
+
+@dataclass
+class CacheData:
+    args_hash: str
+    content: str
+    model: str
+    prompt: str
+    quantized: Optional[np.ndarray] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    mode: str = "default"
+
+
+async def save_to_cache(hashing_kv, cache_data: CacheData):
+    if hashing_kv is None:
+        return
+
+    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+    mode_cache[cache_data.args_hash] = {
+        "return": cache_data.content,
+        "model": cache_data.model,
+        "embedding": cache_data.quantized.tobytes().hex()
+        if cache_data.quantized is not None
+        else None,
+        "embedding_shape": cache_data.quantized.shape
+        if cache_data.quantized is not None
+        else None,
+        "embedding_min": cache_data.min_val,
+        "embedding_max": cache_data.max_val,
+        "original_prompt": cache_data.prompt,
+    }
+
+    await hashing_kv.upsert({cache_data.mode: mode_cache})
 if __name__ == "__main__":
     import asyncio
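Every provider wrapper above now follows the same three-step flow: hash the request, ask handle_cache for a hit, and write the final answer back through save_to_cache. The sketch below illustrates that flow outside the diff; it assumes handle_cache, save_to_cache, CacheData, and compute_args_hash from this PR are importable (file names are not shown in this view, so import paths are omitted), and InMemoryKV and fake_llm are hypothetical stand-ins for the real BaseKVStorage backend and a provider call.

# Illustrative sketch of the refactored cache flow (not part of the diff).
import asyncio


class InMemoryKV:
    """Minimal stand-in for BaseKVStorage: get_by_id/upsert over a plain dict."""

    def __init__(self):
        self._data = {}
        self.global_config = {
            "embedding_cache_config": {"enabled": False, "similarity_threshold": 0.95}
        }

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, data):
        self._data.update(data)


async def fake_llm(messages):
    # Stand-in for the actual provider SDK call.
    return "echo: " + messages[-1]["content"]


async def cached_complete(prompt, hashing_kv, mode="default", model="demo-model"):
    messages = [{"role": "user", "content": prompt}]
    args_hash = compute_args_hash(model, messages)

    # 1) Check the cache (exact hash, or embedding similarity when enabled).
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    # 2) Only call the model on a cache miss.
    content = await fake_llm(messages)

    # 3) Store the answer under the mode-specific cache document.
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return content


# kv = InMemoryKV()
# print(asyncio.run(cached_complete("hello", kv)))   # miss: calls fake_llm
# print(asyncio.run(cached_complete("hello", kv)))   # hit: served from the cache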

View File

@@ -474,7 +474,9 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt, keyword_extraction=True)
+    result = await use_model_func(
+        kw_prompt, keyword_extraction=True, mode=query_param.mode
+    )
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -535,6 +537,7 @@ async def kg_query(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
+        mode=query_param.mode,
     )
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
@@ -1036,6 +1039,7 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )
     if len(response) > len(sys_prompt):
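The query-side change is just plumbing: kg_query and naive_query forward query_param.mode as a keyword argument, and each wrapper pops it (together with hashing_kv) from **kwargs before passing the remaining arguments to the provider SDK, so the extra key never reaches the API call. A minimal illustration of that hand-off, using a hypothetical wrapper and query parameter object, follows.

# Illustrative only: how mode rides through **kwargs and is consumed by the wrapper.
async def demo_model_func(prompt, system_prompt=None, **kwargs):
    hashing_kv = kwargs.pop("hashing_kv", None)   # cache backend, may be None
    mode = kwargs.pop("mode", "default")          # cache namespace for this query
    # Whatever is left (temperature, max_tokens, ...) goes to the provider untouched.
    return f"mode={mode!r}, provider kwargs={sorted(kwargs)}"


class DemoQueryParam:
    # Hypothetical stand-in for QueryParam; only the attribute used here is modeled.
    mode = "naive"


async def demo_query(query, query_param):
    # Mirrors the call sites changed above: mode=query_param.mode is threaded through.
    return await demo_model_func(query, system_prompt="...", mode=query_param.mode)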

View File

@@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll):
 async def get_best_cached_response(
-    hashing_kv, current_embedding, similarity_threshold=0.95
-):
-    """Get the cached response with the highest similarity"""
-    try:
-        # Get all keys
-        all_keys = await hashing_kv.all_keys()
-        max_similarity = 0
-        best_cached_response = None
-        # Get cached data one by one
-        for key in all_keys:
-            cache_data = await hashing_kv.get_by_id(key)
-            if cache_data is None or "embedding" not in cache_data:
-                continue
-            # Convert cached embedding list to ndarray
-            cached_quantized = np.frombuffer(
-                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
-            ).reshape(cache_data["embedding_shape"])
-            cached_embedding = dequantize_embedding(
-                cached_quantized,
-                cache_data["embedding_min"],
-                cache_data["embedding_max"],
-            )
-            similarity = cosine_similarity(current_embedding, cached_embedding)
-            if similarity > max_similarity:
-                max_similarity = similarity
-                best_cached_response = cache_data["return"]
-        if max_similarity > similarity_threshold:
-            return best_cached_response
-        return None
-    except Exception as e:
-        logger.warning(f"Error in get_best_cached_response: {e}")
-        return None
+    hashing_kv,
+    current_embedding,
+    similarity_threshold=0.95,
+    mode="default",
+) -> Union[str, None]:
+    # Get mode-specific cache
+    mode_cache = await hashing_kv.get_by_id(mode)
+    if not mode_cache:
+        return None
+
+    best_similarity = -1
+    best_response = None
+    best_prompt = None
+    best_cache_id = None
+
+    # Only iterate through cache entries for this mode
+    for cache_id, cache_data in mode_cache.items():
+        if cache_data["embedding"] is None:
+            continue
+
+        # Convert cached embedding list to ndarray
+        cached_quantized = np.frombuffer(
+            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
+        ).reshape(cache_data["embedding_shape"])
+        cached_embedding = dequantize_embedding(
+            cached_quantized,
+            cache_data["embedding_min"],
+            cache_data["embedding_max"],
+        )
+
+        similarity = cosine_similarity(current_embedding, cached_embedding)
+        if similarity > best_similarity:
+            best_similarity = similarity
+            best_response = cache_data["return"]
+            best_prompt = cache_data["original_prompt"]
+            best_cache_id = cache_id
+
+    if best_similarity > similarity_threshold:
+        prompt_display = (
+            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
+        )
+        log_data = {
+            "event": "cache_hit",
+            "mode": mode,
+            "similarity": round(best_similarity, 4),
+            "cache_id": best_cache_id,
+            "original_prompt": prompt_display,
+        }
+        logger.info(json.dumps(log_data, ensure_ascii=False))
+        return best_response
+    return None
 def cosine_similarity(v1, v2):
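For reference, after this refactor the hashing KV store holds one document per mode, keyed by the mode name, with per-request entries keyed by args_hash; get_best_cached_response only scans the entries of the requested mode. A sketch of that layout is below; the field names come from save_to_cache above, while every value (hash, model name, embedding dimension, ranges) is invented for illustration.

# Hypothetical snapshot of the cache after one call in "default" mode (values invented).
example_kv_contents = {
    "default": {                          # one top-level document per mode
        "<args_hash of (model, messages)>": {
            "return": "the cached LLM answer",
            "model": "some-model-name",
            "embedding": "7fa03c...",     # uint8-quantized prompt embedding, hex-encoded
            "embedding_shape": (1536,),   # used to rebuild the ndarray via np.frombuffer
            "embedding_min": -0.31,       # range needed by dequantize_embedding
            "embedding_max": 0.29,
            "original_prompt": "the prompt that produced this answer",
        }
    }
}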