lightrag/llm.py
@@ -4,8 +4,8 @@ import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union

from typing import List, Dict, Callable, Any, Union, Optional
from dataclasses import dataclass
import aioboto3
import aiohttp
import numpy as np
@@ -66,39 +66,22 @@ async def openai_complete_if_cache(
    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        if is_embedding_cache_enabled:
            # Use embedding cache
            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
            current_embedding = await embedding_model_func([prompt])
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
@@ -112,24 +95,21 @@ async def openai_complete_if_cache(
    if r"\u" in content:
        content = content.encode("utf-8").decode("unicode_escape")

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": content,
                    "model": model,
                    "embedding": quantized.tobytes().hex()
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_shape": quantized.shape
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_min": min_val if is_embedding_cache_enabled else None,
                    "embedding_max": max_val if is_embedding_cache_enabled else None,
                    "original_prompt": prompt,
                }
            }
        )
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return content

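With this change the OpenAI wrapper pops `hashing_kv` and `mode` out of `kwargs` before the remaining keyword arguments reach the OpenAI client, reads through `handle_cache`, and writes back through `save_to_cache`. A minimal sketch of how a caller might drive the refactored wrapper; the model name and the `kv_storage` instance are placeholders, not part of this diff:

# Hypothetical call site; "gpt-4o-mini" and `kv_storage` are placeholders.
async def demo(kv_storage):
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        "What is LightRAG?",
        system_prompt="Answer briefly.",
        hashing_kv=kv_storage,  # consumed by handle_cache / save_to_cache
        mode="default",         # popped here; never forwarded to the OpenAI SDK
    )
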
@@ -162,6 +142,8 @@ async def azure_openai_complete_if_cache(
    )

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    mode = kwargs.pop("mode", "default")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
@@ -169,56 +151,35 @@ async def azure_openai_complete_if_cache(
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        if is_embedding_cache_enabled:
            # Use embedding cache
            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
            current_embedding = await embedding_model_func([prompt])
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
    # Handle cache
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response.choices[0].message.content,
                    "model": model,
                    "embedding": quantized.tobytes().hex()
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_shape": quantized.shape
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_min": min_val if is_embedding_cache_enabled else None,
                    "embedding_max": max_val if is_embedding_cache_enabled else None,
                    "original_prompt": prompt,
                }
            }
        )
    return response.choices[0].message.content
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return content

class BedrockError(Exception):
@@ -259,6 +220,15 @@ async def bedrock_complete_if_cache(

    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}
@@ -281,34 +251,15 @@ async def bedrock_complete_if_cache(
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        if is_embedding_cache_enabled:
            # Use embedding cache
            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
            current_embedding = await embedding_model_func([prompt])
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    # Call model via Converse API
    session = aioboto3.Session()
@@ -318,30 +269,22 @@ async def bedrock_complete_if_cache(
    except Exception as e:
        raise BedrockError(e)

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response["output"]["message"]["content"][0]["text"],
                    "model": model,
                    "embedding": quantized.tobytes().hex()
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_shape": quantized.shape
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_min": min_val
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_max": max_val
                    if is_embedding_cache_enabled
                    else None,
                    "original_prompt": prompt,
                }
            }
        )
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response["output"]["message"]["content"][0]["text"],
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return response["output"]["message"]["content"][0]["text"]
    return response["output"]["message"]["content"][0]["text"]

@lru_cache(maxsize=1)
@@ -379,32 +322,14 @@ async def hf_model_if_cache(
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        if is_embedding_cache_enabled:
            # Use embedding cache
            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
            current_embedding = await embedding_model_func([prompt])
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    input_prompt = ""
    try:
@@ -449,24 +374,22 @@ async def hf_model_if_cache(
    response_text = hf_tokenizer.decode(
        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
    )
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response_text,
                    "model": model,
                    "embedding": quantized.tobytes().hex()
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_shape": quantized.shape
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_min": min_val if is_embedding_cache_enabled else None,
                    "embedding_max": max_val if is_embedding_cache_enabled else None,
                    "original_prompt": prompt,
                }
            }
        )

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response_text,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return response_text

@@ -497,32 +420,14 @@ async def ollama_model_if_cache(
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        if is_embedding_cache_enabled:
            # Use embedding cache
            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
            current_embedding = await embedding_model_func([prompt])
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
@@ -535,29 +440,39 @@ async def ollama_model_if_cache(
        return inner()
    else:
        result = response["message"]["content"]
        if hashing_kv is not None:
            await hashing_kv.upsert(
                {
                    args_hash: {
                        "return": result,
                        "model": model,
                        "embedding": quantized.tobytes().hex()
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_shape": quantized.shape
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_min": min_val
                        if is_embedding_cache_enabled
                        else None,
                        "embedding_max": max_val
                        if is_embedding_cache_enabled
                        else None,
                        "original_prompt": prompt,
                    }
                }
            )
        # Save to cache
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=result,
                model=model,
                prompt=prompt,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return result
    result = response["message"]["content"]

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=result,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return result

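In the streaming branch the wrapper returns the async generator from inner() before any cache write happens, so only non-streamed results are saved. A rough consumption sketch, assuming a placeholder Ollama model name and KV storage instance:

# Hypothetical consumption of the streaming branch; "qwen2" and `kv_storage` are placeholders.
async def stream_demo(kv_storage):
    chunks = await ollama_model_if_cache(
        "qwen2", "Tell me about LightRAG", stream=True, hashing_kv=kv_storage
    )
    async for piece in chunks:  # async generator of text fragments
        print(piece, end="", flush=True)
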
@lru_cache(maxsize=1)
@@ -668,32 +583,14 @@ async def lmdeploy_model_if_cache(
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        # Calculate args_hash only when using cache
        args_hash = compute_args_hash(model, messages)

        # Get embedding cache configuration
        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
        )
        is_embedding_cache_enabled = embedding_cache_config["enabled"]
        if is_embedding_cache_enabled:
            # Use embedding cache
            embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
            current_embedding = await embedding_model_func([prompt])
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
            best_cached_response = await get_best_cached_response(
                hashing_kv,
                current_embedding[0],
                similarity_threshold=embedding_cache_config["similarity_threshold"],
            )
            if best_cached_response is not None:
                return best_cached_response
        else:
            # Use regular cache
            if_cache_return = await hashing_kv.get_by_id(args_hash)
            if if_cache_return is not None:
                return if_cache_return["return"]
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
@@ -711,24 +608,21 @@ async def lmdeploy_model_if_cache(
    ):
        response += res.response

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {
                args_hash: {
                    "return": response,
                    "model": model,
                    "embedding": quantized.tobytes().hex()
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_shape": quantized.shape
                    if is_embedding_cache_enabled
                    else None,
                    "embedding_min": min_val if is_embedding_cache_enabled else None,
                    "embedding_max": max_val if is_embedding_cache_enabled else None,
                    "original_prompt": prompt,
                }
            }
        )
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return response

@@ -1158,6 +1052,75 @@ class MultiModel:
        return await next_model.gen_func(**args)


async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None

    # Get embedding cache configuration
    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
    )
    is_embedding_cache_enabled = embedding_cache_config["enabled"]

    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use embedding cache
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config["similarity_threshold"],
            mode=mode,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        # Use regular cache
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None

    return None, quantized, min_val, max_val

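For reference, the four-tuple returned by handle_cache feeds straight into save_to_cache on a miss, which is the pattern every provider wrapper above follows. A condensed sketch of that shared flow; `call_llm` stands in for any concrete provider call and is not part of this diff:

# Sketch of the shared read-then-write pattern; helpers come from this module.
async def cached_complete(model, prompt, messages, hashing_kv, call_llm, mode="default"):
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response  # hit: exact args_hash match or semantic embedding match

    content = await call_llm(messages)  # miss: actually query the model

    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,  # reuses the embedding computed during the lookup
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return content
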
@dataclass
class CacheData:
    args_hash: str
    content: str
    model: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"

async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None:
        return

    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "model": cache_data.model,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})

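Because save_to_cache nests entries under the mode key, a JSON-backed KV store ends up holding one document per mode, roughly shaped as below; the hash keys, model name, and numeric values are invented for illustration:

# Illustrative shape of the cache documents (all concrete values invented).
example_cache = {
    "default": {
        "a3f1c2...": {
            "return": "LightRAG is a retrieval-augmented generation library ...",
            "model": "gpt-4o-mini",
            "embedding": "9fa0b17c...",  # hex-encoded uint8 vector, or None
            "embedding_shape": (1536,),
            "embedding_min": -0.12,
            "embedding_max": 0.18,
            "original_prompt": "What is LightRAG?",
        },
    },
    "local": {},   # each query mode gets its own bucket
    "hybrid": {},
}
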
if __name__ == "__main__":
    import asyncio


lightrag/operate.py
@@ -474,7 +474,9 @@ async def kg_query(
    use_model_func = global_config["llm_model_func"]
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
    result = await use_model_func(kw_prompt, keyword_extraction=True)
    result = await use_model_func(
        kw_prompt, keyword_extraction=True, mode=query_param.mode
    )
    logger.info("kw_prompt result:")
    print(result)
    try:
@@ -535,6 +537,7 @@ async def kg_query(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
        mode=query_param.mode,
    )
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
@@ -1036,6 +1039,7 @@ async def naive_query(
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        mode=query_param.mode,
    )

    if len(response) > len(sys_prompt):
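These call sites pass `query_param.mode` down to the LLM wrappers, so cached answers are bucketed by query mode. A small usage sketch in the style of the project README; the working directory, inserted text, and model function choice are placeholders:

# Hypothetical caller; paths and content are placeholders.
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete

rag = LightRAG(working_dir="./rag_storage", llm_model_func=gpt_4o_mini_complete)
rag.insert("LightRAG is a simple and fast retrieval-augmented generation library.")
# "hybrid" is one of the query modes; the same string keys the LLM response cache.
print(rag.query("What is LightRAG?", param=QueryParam(mode="hybrid")))
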
lightrag/utils.py
@@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll):

async def get_best_cached_response(
    hashing_kv, current_embedding, similarity_threshold=0.95
):
    """Get the cached response with the highest similarity"""
    try:
        # Get all keys
        all_keys = await hashing_kv.all_keys()
        max_similarity = 0
        best_cached_response = None

        # Get cached data one by one
        for key in all_keys:
            cache_data = await hashing_kv.get_by_id(key)
            if cache_data is None or "embedding" not in cache_data:
                continue

            # Convert cached embedding list to ndarray
            cached_quantized = np.frombuffer(
                bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
            ).reshape(cache_data["embedding_shape"])
            cached_embedding = dequantize_embedding(
                cached_quantized,
                cache_data["embedding_min"],
                cache_data["embedding_max"],
            )

            similarity = cosine_similarity(current_embedding, cached_embedding)
            if similarity > max_similarity:
                max_similarity = similarity
                best_cached_response = cache_data["return"]

        if max_similarity > similarity_threshold:
            return best_cached_response
    hashing_kv,
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
) -> Union[str, None]:
    # Get mode-specific cache
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None

    except Exception as e:
        logger.warning(f"Error in get_best_cached_response: {e}")
        return None
    best_similarity = -1
    best_response = None
    best_prompt = None
    best_cache_id = None

    # Only iterate through cache entries for this mode
    for cache_id, cache_data in mode_cache.items():
        if cache_data["embedding"] is None:
            continue

        # Convert cached embedding list to ndarray
        cached_quantized = np.frombuffer(
            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
        ).reshape(cache_data["embedding_shape"])
        cached_embedding = dequantize_embedding(
            cached_quantized,
            cache_data["embedding_min"],
            cache_data["embedding_max"],
        )

        similarity = cosine_similarity(current_embedding, cached_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_response = cache_data["return"]
            best_prompt = cache_data["original_prompt"]
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
        log_data = {
            "event": "cache_hit",
            "mode": mode,
            "similarity": round(best_similarity, 4),
            "cache_id": best_cache_id,
            "original_prompt": prompt_display,
        }
        logger.info(json.dumps(log_data, ensure_ascii=False))
        return best_response
    return None

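get_best_cached_response expects the hex string written by save_to_cache to decode into a uint8 array that dequantize_embedding can map back to floats. Those helpers are not shown in this diff; a rough sketch of the assumed min-max quantization round trip, with hypothetical names to avoid clashing with the real functions:

# Assumed behaviour of the quantization helpers (not part of this diff).
import numpy as np

def quantize_embedding_sketch(embedding: np.ndarray, bits: int = 8):
    min_val, max_val = float(embedding.min()), float(embedding.max())
    scale = (max_val - min_val) / ((1 << bits) - 1) or 1.0  # avoid division by zero
    quantized = np.round((embedding - min_val) / scale).astype(np.uint8)
    return quantized, min_val, max_val

def dequantize_embedding_sketch(quantized: np.ndarray, min_val: float, max_val: float, bits: int = 8):
    scale = (max_val - min_val) / ((1 << bits) - 1)
    return quantized.astype(np.float32) * scale + min_val
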
def cosine_similarity(v1, v2):