Fix: args_hash was computed only when the regular cache was used, so it was never computed when the embedding cache was enabled

magicyuan876
2024-12-06 10:21:53 +08:00
parent 2ecdab2f18
commit 6540d11096
2 changed files with 47 additions and 7 deletions
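
In short: each *_complete_if_cache wrapper computed args_hash only inside the regular-cache branch, so whenever the embedding cache was enabled the name was never bound, and any later use of it (such as writing the model response back to the cache) failed. The sketch below illustrates the failure mode and the hoisting fix; the hashing helper is a stand-in, not LightRAG's actual implementation.

import hashlib
import json


def compute_args_hash(*args) -> str:
    # Stand-in for LightRAG's helper: hash the call arguments into a cache key.
    return hashlib.md5(json.dumps(args, sort_keys=True, default=str).encode()).hexdigest()


def buggy(model, messages, embedding_cache_enabled):
    if embedding_cache_enabled:
        pass  # similarity lookup; args_hash is never computed on this path
    else:
        args_hash = compute_args_hash(model, messages)  # only bound here
    # Any later use of args_hash (e.g. the post-call cache write) raises
    # UnboundLocalError when the embedding-cache branch was taken:
    return args_hash


def fixed(model, messages, embedding_cache_enabled):
    # The fix: hoist the hash above the branch so every path can use it.
    args_hash = compute_args_hash(model, messages)
    if embedding_cache_enabled:
        pass  # similarity lookup
    return args_hash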


@@ -66,7 +66,11 @@ async def openai_complete_if_cache(
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -86,7 +90,6 @@ async def openai_complete_if_cache(
                return best_cached_response
        else:
            # Use regular cache
            args_hash = compute_args_hash(model, messages)
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
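
Taken together, the two hunks above give the wrapper its new shape: compute the exact-argument hash up front, then dispatch on the embedding-cache setting. A sketch of the post-commit control flow follows, reusing compute_args_hash from the sketch above; get_best_cached_response is stubbed and its signature is an assumption, while the global_config.get and get_by_id calls mirror the diff:

async def get_best_cached_response(kv, prompt, similarity_threshold):
    # Stub standing in for the real similarity-search helper.
    return None


async def cache_lookup(hashing_kv, model, messages, prompt):
    if hashing_kv is None:
        return None
    args_hash = compute_args_hash(model, messages)  # hoisted: bound on every path
    cfg = hashing_kv.global_config.get(
        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
    )
    if cfg["enabled"]:
        # Semantic tier: nearest stored response above the similarity threshold.
        return await get_best_cached_response(
            hashing_kv, prompt, cfg["similarity_threshold"]
        )
    # Regular tier: exact match on the argument hash.
    hit = await hashing_kv.get_by_id(args_hash)
    return hit["return"] if hit is not None else None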
@@ -159,7 +162,12 @@ async def azure_openai_complete_if_cache(
    messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -178,7 +186,7 @@ async def azure_openai_complete_if_cache(
            if best_cached_response is not None:
                return best_cached_response
        else:
            args_hash = compute_args_hash(model, messages)
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
@@ -271,6 +279,9 @@ async def bedrock_complete_if_cache(
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -290,7 +301,6 @@ async def bedrock_complete_if_cache(
                return best_cached_response
        else:
            # Use regular cache
            args_hash = compute_args_hash(model, messages)
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
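
The best_cached_response these hunks branch on comes from a similarity search over embeddings of previously cached prompts. A minimal sketch of that kind of lookup, assuming a plain dict of args_hash -> {"embedding", "return"} records rather than LightRAG's actual storage layer:

import numpy as np


def best_cached_response(query_vec, cache, threshold=0.95):
    # Return the cached answer whose stored embedding is closest to the query
    # by cosine similarity, or None if nothing clears the threshold.
    best, best_sim = None, -1.0
    for entry in cache.values():
        v = entry["embedding"]
        sim = float(np.dot(query_vec, v) / (np.linalg.norm(query_vec) * np.linalg.norm(v)))
        if sim > best_sim:
            best, best_sim = entry["return"], sim
    return best if best_sim >= threshold else None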
@@ -343,6 +353,11 @@ def initialize_hf_model(model_name):
    return hf_model, hf_tokenizer


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def hf_model_if_cache(
    model,
    prompt,
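
The @retry blocks in this and the later hunks are tenacity's decorator. The same policy in a standalone form (the exception class here is a stand-in for the provider SDK's):

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


class RateLimitError(Exception):
    # Stand-in for the provider SDK's rate-limit exception.
    pass


@retry(
    stop=stop_after_attempt(3),  # give up after three attempts
    wait=wait_exponential(multiplier=1, min=4, max=10),  # backoff clamped to 4-10 s
    retry=retry_if_exception_type(RateLimitError),  # retry only rate limits
)
def call_model():
    raise RateLimitError("simulated 429")

Calling call_model() would sleep between attempts and surface tenacity's RetryError after the third failure.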
@@ -359,7 +374,11 @@ async def hf_model_if_cache(
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -379,7 +398,6 @@ async def hf_model_if_cache(
                return best_cached_response
        else:
            # Use regular cache
            args_hash = compute_args_hash(model, messages)
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
@@ -448,6 +466,11 @@ async def hf_model_if_cache(
    return response_text


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def ollama_model_if_cache(
    model,
    prompt,
@@ -468,7 +491,12 @@ async def ollama_model_if_cache(
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -488,7 +516,6 @@ async def ollama_model_if_cache(
                return best_cached_response
        else:
            # Use regular cache
            args_hash = compute_args_hash(model, messages)
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
@@ -542,6 +569,11 @@ def initialize_lmdeploy_pipeline(
    return lmdeploy_pipe


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def lmdeploy_model_if_cache(
    model,
    prompt,
@@ -620,7 +652,12 @@ async def lmdeploy_model_if_cache(
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)
        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -640,7 +677,6 @@ async def lmdeploy_model_if_cache(
                return best_cached_response
        else:
            # Use regular cache
            args_hash = compute_args_hash(model, messages)
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]