Support multi-turn conversation
@@ -108,8 +108,23 @@ def convert_response_to_json(response: str) -> dict:
         raise e from None


-def compute_args_hash(*args):
-    return md5(str(args).encode()).hexdigest()
+def compute_args_hash(*args, cache_type: str = None) -> str:
+    """Compute a hash for the given arguments.
+    Args:
+        *args: Arguments to hash
+        cache_type: Type of cache (e.g., 'keywords', 'query')
+    Returns:
+        str: Hash string
+    """
+    import hashlib
+
+    # Convert all arguments to strings and join them
+    args_str = "".join([str(arg) for arg in args])
+    if cache_type:
+        args_str = f"{cache_type}:{args_str}"
+
+    # Compute MD5 hash
+    return hashlib.md5(args_str.encode()).hexdigest()


 def compute_mdhash_id(content, prefix: str = ""):
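A minimal usage sketch of the new signature (not part of the diff; the argument values are illustrative): the cache_type prefix namespaces otherwise identical arguments, so keyword-extraction and query caches cannot collide on the same hash.

    # With compute_args_hash defined as in the hunk above.
    h_query = compute_args_hash("hybrid", "What is LightRAG?", cache_type="query")
    h_keywords = compute_args_hash("hybrid", "What is LightRAG?", cache_type="keywords")
    assert h_query != h_keywords  # same args, different cache namespaces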
@@ -343,8 +358,8 @@ async def get_best_cached_response(
     use_llm_check=False,
     llm_func=None,
     original_prompt=None,
+    cache_type=None,
 ) -> Union[str, None]:
     # Get mode-specific cache
     mode_cache = await hashing_kv.get_by_id(mode)
     if not mode_cache:
         return None
@@ -356,6 +371,10 @@ async def get_best_cached_response(
 
     # Only iterate through cache entries for this mode
     for cache_id, cache_data in mode_cache.items():
+        # Skip if cache_type doesn't match
+        if cache_type and cache_data.get("cache_type") != cache_type:
+            continue
+
         if cache_data["embedding"] is None:
             continue
 
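A standalone sketch of the matching rule this hunk adds (hypothetical cache contents, not part of the diff): entries stored under a different cache_type, or without an embedding, are skipped before any similarity comparison.

    mode_cache = {
        "h1": {"cache_type": "keywords", "embedding": None},
        "h2": {"cache_type": "query", "embedding": [0.12, 0.34]},
    }
    cache_type = "query"
    candidates = [
        cid
        for cid, data in mode_cache.items()
        if (not cache_type or data.get("cache_type") == cache_type)
        and data["embedding"] is not None
    ]
    print(candidates)  # ['h2']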
@@ -452,13 +471,12 @@ def dequantize_embedding(
     return (quantized * scale + min_val).astype(np.float32)


-async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
+async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
     """Generic cache handling function"""
     if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
         return None, None, None, None
 
-    # For naive mode, only use simple cache matching
-    # if mode == "naive":
+    # For default mode, only use simple cache matching
     if mode == "default":
         if exists_func(hashing_kv, "get_by_mode_and_id"):
             mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -492,6 +510,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
             use_llm_check=use_llm_check,
             llm_func=llm_model_func if use_llm_check else None,
             original_prompt=prompt if use_llm_check else None,
+            cache_type=cache_type,
         )
         if best_cached_response is not None:
             return best_cached_response, None, None, None
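A hedged sketch of a call site inside an async query function (the names hashing_kv and query, and the "hybrid" mode, are assumptions not shown in this diff): the hash and the cache lookup share one cache_type so an entry written earlier can be found again.

    args_hash = compute_args_hash("hybrid", query, cache_type="keywords")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, mode="hybrid", cache_type="keywords"
    )
    if cached_response is not None:
        ...  # reuse the cached result instead of calling the LLM again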
@@ -573,3 +592,59 @@ def exists_func(obj, func_name: str) -> bool:
         return True
     else:
         return False
+
+
+def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
+    """
+    Process conversation history to get the specified number of complete turns.
+
+    Args:
+        conversation_history: List of conversation messages in chronological order
+        num_turns: Number of complete turns to include
+
+    Returns:
+        Formatted string of the conversation history
+    """
+    # Group messages into turns
+    turns = []
+    messages = []
+
+    # First, filter out keyword extraction messages
+    for msg in conversation_history:
+        if msg["role"] == "assistant" and (
+            msg["content"].startswith('{ "high_level_keywords"')
+            or msg["content"].startswith("{'high_level_keywords'")
+        ):
+            continue
+        messages.append(msg)
+
+    # Then process messages in chronological order
+    i = 0
+    while i < len(messages) - 1:
+        msg1 = messages[i]
+        msg2 = messages[i + 1]
+
+        # Check if we have a user-assistant or assistant-user pair
+        if (msg1["role"] == "user" and msg2["role"] == "assistant") or (
+            msg1["role"] == "assistant" and msg2["role"] == "user"
+        ):
+            # Always put user message first in the turn
+            if msg1["role"] == "assistant":
+                turn = [msg2, msg1]  # user, assistant
+            else:
+                turn = [msg1, msg2]  # user, assistant
+            turns.append(turn)
+        i += 1
+
+    # Keep only the most recent num_turns
+    if len(turns) > num_turns:
+        turns = turns[-num_turns:]
+
+    # Format the turns into a string
+    formatted_turns = []
+    for turn in turns:
+        formatted_turns.extend(
+            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
+        )
+
+    return "\n".join(formatted_turns)
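A runnable usage sketch with get_conversation_turns defined as above (the sample messages are invented): the keyword-extraction reply is filtered out, and the remaining user/assistant pair is formatted as one turn.

    history = [
        {"role": "user", "content": "What is LightRAG?"},
        {"role": "assistant", "content": '{ "high_level_keywords": ["LightRAG"]}'},
        {"role": "assistant", "content": "A retrieval-augmented generation framework."},
    ]
    print(get_conversation_turns(history, num_turns=3))
    # user: What is LightRAG?
    # assistant: A retrieval-augmented generation framework.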