Support multi-turn conversations

Magic_yuan
2025-01-24 18:59:24 +08:00
parent 3d93df4049
commit 5719aa8882
5 changed files with 479 additions and 364 deletions


@@ -108,8 +108,23 @@ def convert_response_to_json(response: str) -> dict:
        raise e from None


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()


def compute_args_hash(*args, cache_type: str = None) -> str:
    """Compute a hash for the given arguments.

    Args:
        *args: Arguments to hash
        cache_type: Type of cache (e.g., 'keywords', 'query')

    Returns:
        str: Hash string
    """
    import hashlib

    # Convert all arguments to strings and join them
    args_str = "".join([str(arg) for arg in args])
    if cache_type:
        args_str = f"{cache_type}:{args_str}"

    # Compute MD5 hash
    return hashlib.md5(args_str.encode()).hexdigest()


def compute_mdhash_id(content, prefix: str = ""):
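To illustrate the new cache_type keyword (a usage sketch, not part of the diff; the argument values are made up): the same arguments now hash to different keys when a different cache_type is supplied, so keyword-extraction results and query answers no longer share cache entries.

h_query = compute_args_hash("mix", "Who is Scrooge?", cache_type="query")
h_keywords = compute_args_hash("mix", "Who is Scrooge?", cache_type="keywords")
assert h_query != h_keywords  # "query:..." and "keywords:..." hash differently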
@@ -343,8 +358,8 @@ async def get_best_cached_response(
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
    cache_type=None,
) -> Union[str, None]:
    # Get mode-specific cache
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None
@@ -356,6 +371,10 @@ async def get_best_cached_response(
    # Only iterate through cache entries for this mode
    for cache_id, cache_data in mode_cache.items():
        # Skip if cache_type doesn't match
        if cache_type and cache_data.get("cache_type") != cache_type:
            continue

        if cache_data["embedding"] is None:
            continue
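A hedged sketch of what the new filter does inside get_best_cached_response (the entry contents below are illustrative, not taken from the diff): entries stored under the same mode but with a different cache_type are skipped before any embedding similarity is computed.

mode_cache = {
    "id-1": {"cache_type": "keywords", "embedding": None},
    "id-2": {"cache_type": "query", "embedding": [12, 7, 250]},
}
requested_type = "query"
surviving = [
    cid
    for cid, data in mode_cache.items()
    if data.get("cache_type") == requested_type and data["embedding"] is not None
]
# surviving == ["id-2"]; "id-1" is skipped by the cache_type check (and has no embedding)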
@@ -452,13 +471,12 @@ def dequantize_embedding(
    return (quantized * scale + min_val).astype(np.float32)


async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
async def handle_cache(hashing_kv, args_hash, prompt, mode="default", cache_type=None):
    """Generic cache handling function"""
    if hashing_kv is None or not hashing_kv.global_config.get("enable_llm_cache"):
        return None, None, None, None

    # For naive mode, only use simple cache matching
    # if mode == "naive":
    # For default mode, only use simple cache matching
    if mode == "default":
        if exists_func(hashing_kv, "get_by_mode_and_id"):
            mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -492,6 +510,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
        use_llm_check=use_llm_check,
        llm_func=llm_model_func if use_llm_check else None,
        original_prompt=prompt if use_llm_check else None,
        cache_type=cache_type,
    )
    if best_cached_response is not None:
        return best_cached_response, None, None, None
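A call-site sketch for the extended handle_cache signature (the surrounding names such as query_param are illustrative assumptions, not shown in this diff): the caller computes a type-tagged hash and passes the same cache_type through, so lookups only match entries of that type.

args_hash = compute_args_hash(query_param.mode, query, cache_type="keywords")
cached_response, quantized, min_val, max_val = await handle_cache(
    hashing_kv, args_hash, query, query_param.mode, cache_type="keywords"
)
if cached_response is not None:
    return cached_response  # reuse the cached keyword-extraction result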
@@ -573,3 +592,59 @@ def exists_func(obj, func_name: str) -> bool:
        return True
    else:
        return False


def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
    """
    Process conversation history to get the specified number of complete turns.

    Args:
        conversation_history: List of conversation messages in chronological order
        num_turns: Number of complete turns to include

    Returns:
        Formatted string of the conversation history
    """
    # Group messages into turns
    turns = []
    messages = []

    # First, filter out keyword extraction messages
    for msg in conversation_history:
        if msg["role"] == "assistant" and (
            msg["content"].startswith('{ "high_level_keywords"')
            or msg["content"].startswith("{'high_level_keywords'")
        ):
            continue
        messages.append(msg)

    # Then process messages in chronological order
    i = 0
    while i < len(messages) - 1:
        msg1 = messages[i]
        msg2 = messages[i + 1]

        # Check if we have a user-assistant or assistant-user pair
        if (msg1["role"] == "user" and msg2["role"] == "assistant") or (
            msg1["role"] == "assistant" and msg2["role"] == "user"
        ):
            # Always put user message first in the turn
            if msg1["role"] == "assistant":
                turn = [msg2, msg1]  # user, assistant
            else:
                turn = [msg1, msg2]  # user, assistant
            turns.append(turn)
        i += 1

    # Keep only the most recent num_turns
    if len(turns) > num_turns:
        turns = turns[-num_turns:]

    # Format the turns into a string
    formatted_turns = []
    for turn in turns:
        formatted_turns.extend(
            [f"user: {turn[0]['content']}", f"assistant: {turn[1]['content']}"]
        )

    return "\n".join(formatted_turns)