feat(lightrag): add query-time embedding cache

- Add an embedding_cache_config option to the LightRAG class
- Implement embedding-similarity-based cache lookup and storage
- Add quantize/dequantize helpers to compress cached embedding data
- Add an example demonstrating embedding cache usage
Author: magicyuan876
Date: 2024-12-06 08:17:20 +08:00
Parent: 645890aff6
Commit: d48c6e4588

5 changed files with 431 additions and 34 deletions
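The new option is passed to the LightRAG constructor. A minimal configuration sketch, using the default values visible in the diff below (the working directory is a hypothetical path, not part of this commit):

from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",  # hypothetical path
    embedding_cache_config={
        "enabled": True,               # turn on similarity-based cache lookup
        "similarity_threshold": 0.95,  # reuse a cached answer at or above this cosine similarity
    },
)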

lightrag/llm.py

@@ -33,6 +33,8 @@ from .utils import (
     compute_args_hash,
     wrap_embedding_func_with_attrs,
     locate_json_string_body_from_string,
+    quantize_embedding,
+    get_best_cached_response,
 )
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
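The diff imports quantize_embedding from .utils but never shows its body. A minimal sketch of an 8-bit min-max quantizer consistent with the embedding_min/embedding_max/embedding_shape fields stored later in this diff (the actual implementation in .utils may differ):

import numpy as np

def quantize_embedding(embedding: np.ndarray, bits: int = 8):
    # Map each float linearly onto the integers 0..2**bits - 1
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
    return quantized, min_val, max_val

def dequantize_embedding(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8):
    # Invert the mapping; per-component error is at most half a quantization step
    scale = (max_val - min_val) / (2**bits - 1)
    return quantized.astype(np.float32) * scale + min_val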
@@ -65,10 +67,29 @@ async def openai_complete_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        args_hash = compute_args_hash(model, messages)  # needed later by the upsert on every path
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     if "response_format" in kwargs:
         response = await openai_async_client.beta.chat.completions.parse(
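get_best_cached_response is likewise imported from .utils without its body shown. A sketch of what such a similarity lookup could look like, assuming the record layout written by the upsert calls below and an all_keys() method on the KV storage (both are assumptions, not confirmed by this diff):

import numpy as np

async def get_best_cached_response(hashing_kv, current_embedding, similarity_threshold=0.95):
    # Scan cached records and return the stored answer whose prompt embedding
    # is most similar to the current prompt, or None if nothing clears the bar.
    best_response, best_sim = None, -1.0
    for key in await hashing_kv.all_keys():
        entry = await hashing_kv.get_by_id(key)
        if not entry or entry.get("embedding") is None:
            continue
        # Decode the hex-encoded uint8 vector, then undo the 8-bit min-max quantization
        q = np.frombuffer(bytes.fromhex(entry["embedding"]), dtype=np.uint8)
        q = q.reshape(entry["embedding_shape"]).astype(np.float32)
        lo, hi = entry["embedding_min"], entry["embedding_max"]
        cached = q * (hi - lo) / 255.0 + lo
        sim = float(np.dot(current_embedding, cached)) / (
            float(np.linalg.norm(current_embedding)) * float(np.linalg.norm(cached))
        )
        if sim > best_sim:
            best_sim, best_response = sim, entry["return"]
    return best_response if best_sim >= similarity_threshold else None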
@@ -81,10 +102,24 @@ async def openai_complete_if_cache(
     content = response.choices[0].message.content
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
     # print(content)
 
     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+            {
+                args_hash: {
+                    "return": content,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
         )
     return content
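The embedding is stored as a hex string rather than raw bytes so the record stays JSON-serializable. The round-trip is lossless:

import numpy as np

quantized = np.array([0, 127, 255], dtype=np.uint8)
hex_str = quantized.tobytes().hex()                        # "007fff" -- what gets cached
restored = np.frombuffer(bytes.fromhex(hex_str), dtype=np.uint8)
assert (restored == quantized).all()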
@@ -125,10 +160,28 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        args_hash = compute_args_hash(model, messages)  # needed later by the upsert on every path
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
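The dict.get call above means deployments that never set embedding_cache_config silently keep the old exact-hash behavior:

global_config = {}  # no "embedding_cache_config" key configured
cfg = global_config.get(
    "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
)
print(cfg["enabled"])  # False -> falls through to the regular args-hash cache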
@@ -136,7 +189,21 @@ async def azure_openai_complete_if_cache(
 
     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+            {
+                args_hash: {
+                    "return": response.choices[0].message.content,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
         )
 
     return response.choices[0].message.content
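For reference, the shape of one cache record as written by these upsert calls. The type below is illustrative, not part of the commit; functional TypedDict syntax is needed because "return" is a Python keyword:

from typing import Optional, Tuple, TypedDict

CacheEntry = TypedDict("CacheEntry", {
    "return": str,                                 # the generated answer
    "model": str,
    "embedding": Optional[str],                    # hex-encoded uint8 bytes, None when disabled
    "embedding_shape": Optional[Tuple[int, ...]],
    "embedding_min": Optional[float],
    "embedding_max": Optional[float],
    "original_prompt": str,
})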
@@ -204,10 +271,29 @@ async def bedrock_complete_if_cache(
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
 
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        args_hash = compute_args_hash(model, messages)  # needed later by the upsert on every path
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     # Call model via Converse API
     session = aioboto3.Session()
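The motivation for this second cache tier: the regular cache keys on an exact hash of the call arguments, so trivially rephrased prompts always miss. A toy illustration with an md5-based stand-in for compute_args_hash (the real hashing function may differ):

from hashlib import md5

def args_hash(*args):
    # stand-in for compute_args_hash; exact-match keying only
    return md5(str(args).encode()).hexdigest()

print(args_hash("gpt-4o", "What are the top themes?"))
print(args_hash("gpt-4o", "What are the main themes?"))
# Two different keys for near-identical prompts -- only the
# embedding-similarity path can serve the second from cache.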
@@ -223,6 +309,19 @@ async def bedrock_complete_if_cache(
                 args_hash: {
                     "return": response["output"]["message"]["content"][0]["text"],
                     "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_max": max_val
+                    if is_embedding_cache_enabled
+                    else None,
+                    "original_prompt": prompt,
                 }
             }
         )
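Quantization is what keeps these records small. Rough storage math for one cached vector (a 1536-dimension embedding is just an example):

dim = 1536
float32_bytes = dim * 4      # 6144 bytes for the raw float32 vector
uint8_bytes = dim * 1        # 1536 bytes after 8-bit quantization
hex_chars = uint8_bytes * 2  # 3072 characters once hex-encoded for JSON
print(float32_bytes, uint8_bytes, hex_chars)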
@@ -245,7 +344,11 @@ def initialize_hf_model(model_name):
 
 async def hf_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
 ) -> str:
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
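A side note on the reflowed signature: history_messages=[] is a mutable default, which Python evaluates once at definition time, not per call. This commit merely reformats it; the usual idiom is a None default:

def bad(item, acc=[]):        # one list shared by every call
    acc.append(item)
    return acc

def good(item, acc=None):     # a fresh list per call
    acc = [] if acc is None else acc
    acc.append(item)
    return acc

print(bad(1), bad(2))    # [1, 2] [1, 2] -- both names point at the same list
print(good(1), good(2))  # [1] [2]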
@@ -257,10 +360,30 @@ async def hf_model_if_cache(
     messages.append({"role": "user", "content": prompt})
 
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        args_hash = compute_args_hash(model, messages)  # needed later by the upsert on every path
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
     input_prompt = ""
     try:
         input_prompt = hf_tokenizer.apply_chat_template(
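The cache code assumes the configured embedding function takes a batch of strings and returns one vector per string, which is why the single prompt is wrapped in a list and then indexed with [0]. A runnable stand-in (the dimension 384 is made up):

import asyncio
import numpy as np

async def fake_embedding_func(texts):
    # stand-in for hashing_kv.global_config["embedding_func"]["func"]
    return np.random.rand(len(texts), 384).astype(np.float32)

async def main():
    batch = await fake_embedding_func(["What is LightRAG?"])
    vector = batch[0]    # mirrors current_embedding[0] in the diff
    print(vector.shape)  # (384,)

asyncio.run(main())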
@@ -305,12 +428,32 @@ async def hf_model_if_cache(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response_text,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return response_text
 
 
 async def ollama_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
 ) -> str:
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
@@ -326,18 +469,52 @@ async def ollama_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
 
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        args_hash = compute_args_hash(model, messages)  # needed later by the upsert on every path
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
 
     result = response["message"]["content"]
 
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": result,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
+
     return result
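What a 0.95 similarity_threshold means in practice: small perturbations of a vector stay well above it, while unrelated vectors land near zero. An illustration with random vectors (real prompt embeddings behave analogously but are model-dependent):

import numpy as np

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(1)
v = rng.standard_normal(768)
near = v + 0.1 * rng.standard_normal(768)  # small perturbation, e.g. a close paraphrase
far = rng.standard_normal(768)             # an unrelated prompt
print(cosine(v, near))  # ~0.99 -> served from the embedding cache
print(cosine(v, far))   # ~0.0  -> falls through to the model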
@@ -444,10 +621,29 @@ async def lmdeploy_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
-        args_hash = compute_args_hash(model, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
+        args_hash = compute_args_hash(model, messages)  # needed later by the upsert on every path
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
+            if_cache_return = await hashing_kv.get_by_id(args_hash)
+            if if_cache_return is not None:
+                return if_cache_return["return"]
 
     gen_config = GenerationConfig(
         skip_special_tokens=skip_special_tokens,
@@ -466,7 +662,23 @@ async def lmdeploy_model_if_cache(
         response += res.response
 
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
+
     return response
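End to end, the feature is exercised roughly like this. A sketch only: the questions and paths are illustrative, and the aquery/QueryParam usage is assumed from LightRAG's public API rather than shown in this diff:

import asyncio
from lightrag import LightRAG, QueryParam

rag = LightRAG(
    working_dir="./rag_storage",
    embedding_cache_config={"enabled": True, "similarity_threshold": 0.95},
)

async def main():
    # First call goes to the LLM and populates the cache
    print(await rag.aquery("What are the top themes?", param=QueryParam(mode="hybrid")))
    # A near-duplicate question can now be answered from the embedding cache
    print(await rag.aquery("What are the main themes?", param=QueryParam(mode="hybrid")))

asyncio.run(main())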