@@ -596,6 +596,7 @@ if __name__ == "__main__":
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
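For illustration, a minimal sketch (not part of the diff) of how these cache-related options can be passed to the `LightRAG` constructor. The stub model functions and the numeric values are placeholders; in practice you would plug in real LLM and embedding functions such as the Upstage ones in the example script below.

```python
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc


# Placeholder model functions so the sketch is self-contained; swap in real
# async LLM / embedding calls for actual use.
async def my_llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return "stub response"


async def my_embedding_func(texts: list[str]) -> np.ndarray:
    return np.random.rand(len(texts), 1024)


rag = LightRAG(
    working_dir="./my_working_dir",
    enable_llm_cache=True,  # cache LLM responses keyed by a hash of the prompt
    embedding_cache_config={  # opt in to similarity-based cache lookups
        "enabled": True,
        "similarity_threshold": 0.95,
    },
    addon_params={"example_number": 1, "language": "Simplified Chinese"},
    llm_model_func=my_llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,  # illustrative; must match the embedding model used
        max_token_size=8192,
        func=my_embedding_func,
    ),
)
```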
## API Server Implementation
examples/lightrag_openai_compatible_demo_embedding_cache.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        "solar-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar",
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model="solar-embedding-1-large-query",
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar",
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


# function test
async def test_funcs():
    result = await llm_model_func("How are you?")
    print("llm_model_func: ", result)

    result = await embedding_func(["How are you?"])
    print("embedding_func: ", result)


# asyncio.run(test_funcs())


async def main():
    try:
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        rag = LightRAG(
            working_dir=WORKING_DIR,
            embedding_cache_config={
                "enabled": True,
                "similarity_threshold": 0.90,
            },
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=8192,
                func=embedding_func,
            ),
        )

        with open("./book.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        # Perform naive search
        print(
            await rag.aquery(
                "What are the top themes in this story?", param=QueryParam(mode="naive")
            )
        )

        # Perform local search
        print(
            await rag.aquery(
                "What are the top themes in this story?", param=QueryParam(mode="local")
            )
        )

        # Perform global search
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode="global"),
            )
        )

        # Perform hybrid search
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode="hybrid"),
            )
        )
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(main())
lightrag/lightrag.py
@@ -85,7 +85,10 @@ class LightRAG:
     working_dir: str = field(
         default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
+    # Default not to use embedding cache
+    embedding_cache_config: dict = field(
+        default_factory=lambda: {"enabled": False, "similarity_threshold": 0.95}
+    )
     kv_storage: str = field(default="JsonKVStorage")
     vector_storage: str = field(default="NanoVectorDBStorage")
     graph_storage: str = field(default="NetworkXStorage")
lightrag/llm.py (230 lines changed)
@@ -33,6 +33,8 @@ from .utils import (
     compute_args_hash,
     wrap_embedding_func_with_attrs,
     locate_json_string_body_from_string,
+    quantize_embedding,
+    get_best_cached_response,
 )

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -65,6 +67,25 @@ async def openai_complete_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
             args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
@@ -81,10 +102,24 @@ async def openai_complete_if_cache(
     content = response.choices[0].message.content
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
     # print(content)

     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+            {
+                args_hash: {
+                    "return": content,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
         )
     return content
@@ -125,6 +160,24 @@ async def azure_openai_complete_if_cache(
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
             args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
@@ -136,7 +189,21 @@ async def azure_openai_complete_if_cache(

     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+            {
+                args_hash: {
+                    "return": response.choices[0].message.content,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
         )
     return response.choices[0].message.content
@@ -204,6 +271,25 @@ async def bedrock_complete_if_cache(

     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
             args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
@@ -223,6 +309,19 @@ async def bedrock_complete_if_cache(
                 args_hash: {
                     "return": response["output"]["message"]["content"][0]["text"],
                     "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_max": max_val
+                    if is_embedding_cache_enabled
+                    else None,
+                    "original_prompt": prompt,
                 }
             }
         )
@@ -245,7 +344,11 @@ def initialize_hf_model(model_name):


 async def hf_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
 ) -> str:
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
@@ -257,10 +360,30 @@ async def hf_model_if_cache(
     messages.append({"role": "user", "content": prompt})

     if hashing_kv is not None:
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
             args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]

     input_prompt = ""
     try:
         input_prompt = hf_tokenizer.apply_chat_template(
@@ -305,12 +428,32 @@ async def hf_model_if_cache(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response_text,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return response_text


 async def ollama_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
 ) -> str:
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
@@ -326,6 +469,25 @@ async def ollama_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
             args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
@@ -336,8 +498,23 @@ async def ollama_model_if_cache(
     result = response["message"]["content"]

     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": result,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return result
@@ -444,6 +621,25 @@ async def lmdeploy_model_if_cache(
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Get embedding cache configuration
+        embedding_cache_config = hashing_kv.global_config.get(
+            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
+        )
+        is_embedding_cache_enabled = embedding_cache_config["enabled"]
+        if is_embedding_cache_enabled:
+            # Use embedding cache
+            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
+            current_embedding = await embedding_model_func([prompt])
+            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
+            best_cached_response = await get_best_cached_response(
+                hashing_kv,
+                current_embedding[0],
+                similarity_threshold=embedding_cache_config["similarity_threshold"],
+            )
+            if best_cached_response is not None:
+                return best_cached_response
+        else:
+            # Use regular cache
             args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
@@ -466,7 +662,23 @@ async def lmdeploy_model_if_cache(
         response += res.response

     if hashing_kv is not None:
-        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+        await hashing_kv.upsert(
+            {
+                args_hash: {
+                    "return": response,
+                    "model": model,
+                    "embedding": quantized.tobytes().hex()
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_shape": quantized.shape
+                    if is_embedding_cache_enabled
+                    else None,
+                    "embedding_min": min_val if is_embedding_cache_enabled else None,
+                    "embedding_max": max_val if is_embedding_cache_enabled else None,
+                    "original_prompt": prompt,
+                }
+            }
+        )
     return response
lightrag/utils.py
@@ -307,3 +307,72 @@ def process_combine_contexts(hl, ll):
     combined_sources_result = "\n".join(combined_sources_result)

     return combined_sources_result
+
+
+async def get_best_cached_response(
+    hashing_kv, current_embedding, similarity_threshold=0.95
+):
+    """Get the cached response with the highest similarity"""
+    try:
+        # Get all keys
+        all_keys = await hashing_kv.all_keys()
+        max_similarity = 0
+        best_cached_response = None
+
+        # Get cached data one by one
+        for key in all_keys:
+            cache_data = await hashing_kv.get_by_id(key)
+            if cache_data is None or "embedding" not in cache_data:
+                continue
+
+            # Convert cached embedding list to ndarray
+            cached_quantized = np.frombuffer(
+                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
+            ).reshape(cache_data["embedding_shape"])
+            cached_embedding = dequantize_embedding(
+                cached_quantized,
+                cache_data["embedding_min"],
+                cache_data["embedding_max"],
+            )
+
+            similarity = cosine_similarity(current_embedding, cached_embedding)
+            if similarity > max_similarity:
+                max_similarity = similarity
+                best_cached_response = cache_data["return"]
+
+        if max_similarity > similarity_threshold:
+            return best_cached_response
+        return None
+
+    except Exception as e:
+        logger.warning(f"Error in get_best_cached_response: {e}")
+        return None
+
+
+def cosine_similarity(v1, v2):
+    """Calculate cosine similarity between two vectors"""
+    dot_product = np.dot(v1, v2)
+    norm1 = np.linalg.norm(v1)
+    norm2 = np.linalg.norm(v2)
+    return dot_product / (norm1 * norm2)
+
+
+def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
+    """Quantize embedding to specified bits"""
+    # Calculate min/max values for reconstruction
+    min_val = embedding.min()
+    max_val = embedding.max()
+
+    # Quantize to 0-255 range
+    scale = (2**bits - 1) / (max_val - min_val)
+    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
+
+    return quantized, min_val, max_val
+
+
+def dequantize_embedding(
+    quantized: np.ndarray, min_val: float, max_val: float, bits=8
+) -> np.ndarray:
+    """Restore quantized embedding"""
+    scale = (max_val - min_val) / (2**bits - 1)
+    return (quantized * scale + min_val).astype(np.float32)
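For reference, a minimal sketch (not part of the diff) of how the quantization helpers behave, assuming they are importable from `lightrag.utils` once this change lands: an embedding is quantized to `uint8` for compact storage in the KV cache, dequantized on lookup, and compared against the query embedding with `cosine_similarity` and the configured `similarity_threshold`.

```python
import numpy as np

# Assumes the helpers added in this diff are exposed by lightrag.utils.
from lightrag.utils import quantize_embedding, dequantize_embedding, cosine_similarity

rng = np.random.default_rng(0)
embedding = rng.normal(size=1024).astype(np.float32)

# 8-bit quantization shrinks the cached payload by roughly 4x versus float32.
quantized, min_val, max_val = quantize_embedding(embedding)
restored = dequantize_embedding(quantized, min_val, max_val)

# The round trip is lossy, but similarity to the original stays near 1.0,
# comfortably above the default 0.95 threshold used for cache hits.
print(quantized.dtype, quantized.nbytes, embedding.nbytes)  # uint8 1024 4096
print(round(float(cosine_similarity(embedding, restored)), 4))
```

Storing the quantized bytes as hex together with `embedding_min`, `embedding_max`, and `embedding_shape` is what lets `get_best_cached_response` rebuild each cached vector during lookup.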