Refactor cache handling logic
- Extract the shared cache-handling logic into new handle_cache and save_to_cache functions
- Unify the cached data structure with a CacheData class
- Streamline the processing flow for both the embedding cache and the regular cache
- Add a mode parameter so different query modes can use different caching strategies
- Refactor get_best_cached_response to make cache lookups more efficient
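In practical terms, the change to the stored cache layout looks like the sketch below (hashes, field values and mode names are placeholders, fields abridged): entries that used to sit flat under their argument hash are now grouped per query mode, which is what lets each query mode apply its own caching strategy in the diff that follows.

# Illustrative shapes only; hashes, values and mode names are placeholders.
# Before: one flat KV entry per argument hash.
cache_before = {
    "9f3a...": {
        "return": "LLM answer",
        "model": "gpt-4o-mini",
        "embedding": None,
        "original_prompt": "what is LightRAG?",
    },
}

# After: entries are grouped under the query mode ("default", "local", ...),
# so each mode keeps its own bucket of cached answers.
cache_after = {
    "local": {
        "9f3a...": {
            "return": "LLM answer",
            "model": "gpt-4o-mini",
            "embedding": None,
            "embedding_min": None,
            "embedding_max": None,
            "original_prompt": "what is LightRAG?",
        },
    },
}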
lightrag/llm.py (458 changed lines)
@@ -4,7 +4,8 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Optional
+from dataclasses import dataclass

 import aioboto3
 import aiohttp
@@ -59,39 +60,21 @@ async def openai_complete_if_cache(
     openai_async_client = (
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
@@ -105,24 +88,21 @@ async def openai_complete_if_cache(
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return content
@@ -155,6 +135,8 @@ async def azure_openai_complete_if_cache(
     )

     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    mode = kwargs.pop("mode", "default")

     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
@@ -162,56 +144,35 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
+    content = response.choices[0].message.content

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response.choices[0].message.content,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
-    return response.choices[0].message.content
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=content,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )
+    return content


 class BedrockError(Exception):
@@ -253,6 +214,15 @@ async def bedrock_complete_if_cache(
     # Add user prompt
     messages.append({"role": "user", "content": [{"text": prompt}]})

+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response
+
     # Initialize Converse API arguments
     args = {"modelId": model, "messages": messages}
@@ -275,33 +245,14 @@ async def bedrock_complete_if_cache(
                 kwargs.pop(param)
             )

-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        kwargs.get("hashing_kv"), args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     # Call model via Converse API
     session = aioboto3.Session()
@@ -311,27 +262,19 @@ async def bedrock_complete_if_cache(
     except Exception as e:
         raise BedrockError(e)

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response["output"]["message"]["content"][0]["text"],
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_max": max_val
-                    if is_embedding_cache_enabled
-                    else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        kwargs.get("hashing_kv"),
+        CacheData(
+            args_hash=args_hash,
+            content=response["output"]["message"]["content"][0]["text"],
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return response["output"]["message"]["content"][0]["text"]
@@ -372,32 +315,14 @@ async def hf_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     input_prompt = ""
     try:
@@ -442,24 +367,22 @@ async def hf_model_if_cache(
     response_text = hf_tokenizer.decode(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response_text,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response_text,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return response_text
@@ -489,55 +412,34 @@ async def ollama_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)

     result = response["message"]["content"]

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return result
@@ -649,32 +551,14 @@ async def lmdeploy_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})

-    if hashing_kv is not None:
-        # Calculate args_hash only when using cache
-        args_hash = compute_args_hash(model, messages)
-        # Get embedding cache configuration
-        embedding_cache_config = hashing_kv.global_config.get(
-            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
-        )
-        is_embedding_cache_enabled = embedding_cache_config["enabled"]
-        if is_embedding_cache_enabled:
-            # Use embedding cache
-            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
-            current_embedding = await embedding_model_func([prompt])
-            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
-            best_cached_response = await get_best_cached_response(
-                hashing_kv,
-                current_embedding[0],
-                similarity_threshold=embedding_cache_config["similarity_threshold"],
-            )
-            if best_cached_response is not None:
-                return best_cached_response
-        else:
-            # Use regular cache
-            if_cache_return = await hashing_kv.get_by_id(args_hash)
-            if if_cache_return is not None:
-                return if_cache_return["return"]
+    # Handle cache
+    mode = kwargs.pop("mode", "default")
+    args_hash = compute_args_hash(model, messages)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, prompt, mode
+    )
+    if cached_response is not None:
+        return cached_response

     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -692,24 +576,21 @@ async def lmdeploy_model_if_cache(
     ):
         response += res.response

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
-                }
-            }
-        )
+    # Save to cache
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            model=model,
+            prompt=prompt,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=mode,
+        ),
+    )

     return response
@@ -1139,6 +1020,75 @@ class MultiModel:
         return await next_model.gen_func(**args)


+async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+    """Generic cache handling function"""
+    if hashing_kv is None:
+        return None, None, None, None
+
+    # Get embedding cache configuration
+    embedding_cache_config = hashing_kv.global_config.get(
+        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+    )
+    is_embedding_cache_enabled = embedding_cache_config["enabled"]
+
+    quantized = min_val = max_val = None
+    if is_embedding_cache_enabled:
+        # Use embedding cache
+        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+        current_embedding = await embedding_model_func([prompt])
+        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+        best_cached_response = await get_best_cached_response(
+            hashing_kv,
+            current_embedding[0],
+            similarity_threshold=embedding_cache_config["similarity_threshold"],
+            mode=mode,
+        )
+        if best_cached_response is not None:
+            return best_cached_response, None, None, None
+    else:
+        # Use regular cache
+        mode_cache = await hashing_kv.get_by_id(mode) or {}
+        if args_hash in mode_cache:
+            return mode_cache[args_hash]["return"], None, None, None
+
+    return None, quantized, min_val, max_val
+
+
+@dataclass
+class CacheData:
+    args_hash: str
+    content: str
+    model: str
+    prompt: str
+    quantized: Optional[np.ndarray] = None
+    min_val: Optional[float] = None
+    max_val: Optional[float] = None
+    mode: str = "default"
+
+
+async def save_to_cache(hashing_kv, cache_data: CacheData):
+    if hashing_kv is None:
+        return
+
+    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
+
+    mode_cache[cache_data.args_hash] = {
+        "return": cache_data.content,
+        "model": cache_data.model,
+        "embedding": cache_data.quantized.tobytes().hex()
+        if cache_data.quantized is not None
+        else None,
+        "embedding_shape": cache_data.quantized.shape
+        if cache_data.quantized is not None
+        else None,
+        "embedding_min": cache_data.min_val,
+        "embedding_max": cache_data.max_val,
+        "original_prompt": cache_data.prompt,
+    }
+
+    await hashing_kv.upsert({cache_data.mode: mode_cache})
+
+
 if __name__ == "__main__":
     import asyncio
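A hedged usage sketch of the helpers added in the hunk above. DictKV, the hash string and the answer text are made-up stand-ins, and it assumes lightrag (with its LLM dependencies) is importable at this revision so that handle_cache, save_to_cache and CacheData come from lightrag.llm; with the embedding cache disabled, handle_cache falls back to the per-mode regular cache.

import asyncio
from lightrag.llm import CacheData, handle_cache, save_to_cache


class DictKV:
    # Minimal in-memory stand-in for the hashing_kv storage used above.
    def __init__(self):
        self._data = {}
        self.global_config = {
            "embedding_cache_config": {"enabled": False, "similarity_threshold": 0.95}
        }

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, data):
        self._data.update(data)


async def demo():
    kv = DictKV()
    # First call: nothing cached for this mode yet.
    hit, *_ = await handle_cache(kv, "hash-1", "what is LightRAG?", mode="local")
    assert hit is None
    await save_to_cache(
        kv,
        CacheData(args_hash="hash-1", content="an answer", model="gpt-4o-mini",
                  prompt="what is LightRAG?", mode="local"),
    )
    # Entries are now grouped per mode: {"local": {"hash-1": {...}}}
    hit, *_ = await handle_cache(kv, "hash-1", "what is LightRAG?", mode="local")
    print(hit)  # "an answer"


asyncio.run(demo())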
@@ -474,7 +474,9 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt, keyword_extraction=True)
+    result = await use_model_func(
+        kw_prompt, keyword_extraction=True, mode=query_param.mode
+    )
     logger.info("kw_prompt result:")
     print(result)
     try:
@@ -534,6 +536,7 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
     )
     if len(response) > len(sys_prompt):
         response = (
@@ -1035,6 +1038,7 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        mode=query_param.mode,
    )

     if len(response) > len(sys_prompt):
@@ -310,19 +310,24 @@ def process_combine_contexts(hl, ll):


 async def get_best_cached_response(
-    hashing_kv, current_embedding, similarity_threshold=0.95
-):
-    """Get the cached response with the highest similarity"""
-    try:
-        # Get all keys
-        all_keys = await hashing_kv.all_keys()
-        max_similarity = 0
-        best_cached_response = None
+    hashing_kv,
+    current_embedding,
+    similarity_threshold=0.95,
+    mode="default",
+) -> Union[str, None]:
+    # Get mode-specific cache
+    mode_cache = await hashing_kv.get_by_id(mode)
+    if not mode_cache:
+        return None

-        # Get cached data one by one
-        for key in all_keys:
-            cache_data = await hashing_kv.get_by_id(key)
-            if cache_data is None or "embedding" not in cache_data:
-                continue
+    best_similarity = -1
+    best_response = None
+    best_prompt = None
+    best_cache_id = None
+
+    # Only iterate through cache entries for this mode
+    for cache_id, cache_data in mode_cache.items():
+        if cache_data["embedding"] is None:
+            continue

         # Convert cached embedding list to ndarray
@@ -336,16 +341,25 @@ async def get_best_cached_response(
         )

         similarity = cosine_similarity(current_embedding, cached_embedding)
-        if similarity > max_similarity:
-            max_similarity = similarity
-            best_cached_response = cache_data["return"]
+        if similarity > best_similarity:
+            best_similarity = similarity
+            best_response = cache_data["return"]
+            best_prompt = cache_data["original_prompt"]
+            best_cache_id = cache_id

-    if max_similarity > similarity_threshold:
-        return best_cached_response
-    return None
-
-    except Exception as e:
-        logger.warning(f"Error in get_best_cached_response: {e}")
+    if best_similarity > similarity_threshold:
+        prompt_display = (
+            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
+        )
+        log_data = {
+            "event": "cache_hit",
+            "mode": mode,
+            "similarity": round(best_similarity, 4),
+            "cache_id": best_cache_id,
+            "original_prompt": prompt_display,
+        }
+        logger.info(json.dumps(log_data))
+        return best_response
     return None
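For the embedding-cache path, the entries above store a quantized embedding plus embedding_min / embedding_max / embedding_shape so the vector can be restored and compared against the current query. The sketch below is a standalone illustration of that idea with numpy; quantize, dequantize and cosine here are simplified stand-ins, not the repository's quantize_embedding or cosine_similarity helpers.

import numpy as np


def quantize(embedding: np.ndarray, bits: int = 8):
    # Map the float vector onto uint8 and keep the min/max needed to undo it,
    # mirroring what the cache stores as embedding / embedding_min / embedding_max.
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (max_val - min_val) / (2**bits - 1) or 1.0
    return ((embedding - min_val) / scale).astype(np.uint8), min_val, max_val


def dequantize(q: np.ndarray, min_val: float, max_val: float, bits: int = 8):
    scale = (max_val - min_val) / (2**bits - 1) or 1.0
    return q.astype(np.float32) * scale + min_val


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


rng = np.random.default_rng(0)
cached_prompt_embedding = rng.normal(size=128).astype(np.float32)
q, lo, hi = quantize(cached_prompt_embedding)

# At query time: restore the cached vector and accept it only above the threshold.
restored = dequantize(q, lo, hi)
query_embedding = cached_prompt_embedding + rng.normal(scale=0.01, size=128).astype(np.float32)
print(cosine(query_embedding, restored) > 0.95)  # True for a near-duplicate prompt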