Merge branch 'main' into fix-entity-name-string

zrguo
2024-12-09 17:30:40 +08:00
committed by GitHub
8 changed files with 406 additions and 289 deletions


@@ -4,8 +4,7 @@ import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union, Optional
from dataclasses import dataclass
from typing import List, Dict, Callable, Any, Union
import aioboto3
import aiohttp
import numpy as np
@@ -27,13 +26,9 @@ from tenacity import (
)
from transformers import AutoTokenizer, AutoModelForCausalLM
from .base import BaseKVStorage
from .utils import (
    compute_args_hash,
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    quantize_embedding,
    get_best_cached_response,
)
import sys
@@ -66,23 +61,13 @@ async def openai_complete_if_cache(
    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs
@@ -95,21 +80,6 @@ async def openai_complete_if_cache(
if r"\u" in content:
content = content.encode("utf-8").decode("unicode_escape")
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=content,
model=model,
prompt=prompt,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=mode,
),
)
return content
@@ -140,10 +110,7 @@ async def azure_openai_complete_if_cache(
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    mode = kwargs.pop("mode", "default")
    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
@@ -151,34 +118,11 @@ async def azure_openai_complete_if_cache(
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})
    # Handle cache
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return content
@@ -210,7 +154,7 @@ async def bedrock_complete_if_cache(
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
kwargs.pop("hashing_kv", None)
# Fix message history format
messages = []
for history_message in history_messages:
@@ -220,15 +164,6 @@ async def bedrock_complete_if_cache(
    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}
@@ -251,15 +186,6 @@ async def bedrock_complete_if_cache(
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
# Handle cache
mode = kwargs.pop("mode", "default")
args_hash = compute_args_hash(model, messages)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, prompt, mode
)
if cached_response is not None:
return cached_response
# Call model via Converse API
session = aioboto3.Session()
@@ -269,21 +195,6 @@ async def bedrock_complete_if_cache(
        except Exception as e:
            raise BedrockError(e)
        # Save to cache
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=response["output"]["message"]["content"][0]["text"],
                model=model,
                prompt=prompt,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return response["output"]["message"]["content"][0]["text"]
@@ -315,22 +226,12 @@ async def hf_model_if_cache(
) -> str:
    model_name = model
    hf_model, hf_tokenizer = initialize_hf_model(model_name)
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    kwargs.pop("hashing_kv", None)
    input_prompt = ""
    try:
        input_prompt = hf_tokenizer.apply_chat_template(
@@ -375,21 +276,6 @@ async def hf_model_if_cache(
        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
    )
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response_text,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return response_text
@@ -410,25 +296,14 @@ async def ollama_model_if_cache(
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)
    kwargs.pop("hashing_kv", None)
    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """ cannot cache stream response """
@@ -439,40 +314,7 @@ async def ollama_model_if_cache(
        return inner()
    else:
        result = response["message"]["content"]
        # Save to cache
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=result,
                model=model,
                prompt=prompt,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return result
    result = response["message"]["content"]
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=result,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return result
        return response["message"]["content"]
@lru_cache(maxsize=1)
@@ -547,7 +389,7 @@ async def lmdeploy_model_if_cache(
        from lmdeploy import version_info, GenerationConfig
    except Exception:
        raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
    kwargs.pop("hashing_kv", None)
    kwargs.pop("response_format", None)
    max_new_tokens = kwargs.pop("max_tokens", 512)
    tp = kwargs.pop("tp", 1)
@@ -579,19 +421,9 @@ async def lmdeploy_model_if_cache(
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
        max_new_tokens=max_new_tokens,
@@ -607,22 +439,6 @@ async def lmdeploy_model_if_cache(
        session_id=1,
    ):
        response += res.response
    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return response
@@ -767,6 +583,39 @@ async def openai_embedding(
    return np.array([dp.embedding for dp in response.data])


async def fetch_data(url, headers, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            response_json = await response.json()
            data_list = response_json.get("data", [])
            return data_list


async def jina_embedding(
    texts: list[str],
    dimensions: int = 1024,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    if api_key:
        os.environ["JINA_API_KEY"] = api_key
    url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
    }
    data = {
        "model": "jina-embeddings-v3",
        "normalized": True,
        "embedding_type": "float",
        "dimensions": f"{dimensions}",
        "late_chunking": late_chunking,
        "input": texts,
    }
    data_list = await fetch_data(url, headers, data)
    return np.array([dp["embedding"] for dp in data_list])
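
For context, a minimal usage sketch of the jina_embedding helper added above. This is hedged: the import path assumes the file shown here is lightrag/llm.py, the input texts are placeholders, and a valid Jina API key must be supplied or already set in JINA_API_KEY.

# Hypothetical usage of the jina_embedding helper added in this diff (not part of the commit).
import asyncio
import numpy as np
from lightrag.llm import jina_embedding  # assumed module path

async def main():
    texts = ["first placeholder document", "second placeholder document"]
    # Returns one embedding row per input text
    embeddings = await jina_embedding(texts, dimensions=1024, api_key="YOUR_JINA_API_KEY")
    assert isinstance(embeddings, np.ndarray)
    print(embeddings.shape)  # expected (2, 1024) for dimensions=1024

if __name__ == "__main__":
    asyncio.run(main())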
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
    stop=stop_after_attempt(3),
@@ -1052,75 +901,6 @@ class MultiModel:
        return await next_model.gen_func(**args)
async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None
    # Get embedding cache configuration
    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
    )
    is_embedding_cache_enabled = embedding_cache_config["enabled"]
    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use embedding cache
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config["similarity_threshold"],
            mode=mode,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        # Use regular cache
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None
    return None, quantized, min_val, max_val
@dataclass
class CacheData:
    args_hash: str
    content: str
    model: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"
async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None:
        return
    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "model": cache_data.model,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }
    await hashing_kv.upsert({cache_data.mode: mode_cache})
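
Taken together, the helpers above (handle_cache, CacheData, save_to_cache) were used in the same wrap-around pattern at every model call site, as the earlier hunks in this file show. A condensed sketch of that pattern follows; cached_complete and call_model are illustrative names introduced here, not identifiers from the commit.

# Illustrative summary of the cache pattern being removed from the completion functions above.
# "call_model" stands in for any provider-specific call (OpenAI, Azure, Bedrock, HF, Ollama, lmdeploy).
async def cached_complete(model, prompt, messages, hashing_kv, call_model, **kwargs):
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    # Return early on a cache hit (exact hash match or embedding-similarity match)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response
    content = await call_model(model=model, messages=messages, **kwargs)
    # Persist the fresh response, including the quantized prompt embedding when enabled
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return content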
if __name__ == "__main__":
    import asyncio