update do_preprocess

tackhwa committed 2024-10-26 16:24:35 +08:00
parent 2e703296d5
commit 2cf3a85a0f

@@ -286,7 +286,9 @@ async def hf_model_if_cache(
     output = hf_model.generate(
         **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
     )
-    response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+    response_text = hf_tokenizer.decode(
+        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+    )
     if hashing_kv is not None:
         await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
     return response_text
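
Note on the decode change above: hf_model.generate returns the prompt tokens followed by the newly generated ones, so slicing off the first len(inputs["input_ids"][0]) tokens before decoding keeps only the model's reply. A minimal stand-alone sketch of that pattern (not part of this commit; the gpt2 checkpoint is only a placeholder):

# Sketch: decode only the tokens generated after the prompt (placeholder model).
from transformers import AutoModelForCausalLM, AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = hf_tokenizer("Hello, how are", return_tensors="pt")
output = hf_model.generate(**inputs, max_new_tokens=8)

# output[0] holds prompt tokens + new tokens; drop the prompt before decoding.
response_text = hf_tokenizer.decode(
    output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
print(response_text)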
@@ -323,19 +325,38 @@ async def ollama_model_if_cache(
 @lru_cache(maxsize=1)
-def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0):
+def initialize_lmdeploy_pipeline(
+    model,
+    tp=1,
+    chat_template=None,
+    log_level="WARNING",
+    model_format="hf",
+    quant_policy=0,
+):
     from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
     lmdeploy_pipe = pipeline(
         model_path=model,
-        backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None,
-        log_level='WARNING')
+        backend_config=TurbomindEngineConfig(
+            tp=tp, model_format=model_format, quant_policy=quant_policy
+        ),
+        chat_template_config=ChatTemplateConfig(model_name=chat_template)
+        if chat_template
+        else None,
+        log_level="WARNING",
+    )
     return lmdeploy_pipe
 
 
 async def lmdeploy_model_if_cache(
-    model, prompt, system_prompt=None, history_messages=[],
-    chat_template=None, model_format='hf',quant_policy=0, **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    chat_template=None,
+    model_format="hf",
+    quant_policy=0,
+    **kwargs,
 ) -> str:
     """
     Args:
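
The lru_cache(maxsize=1) decorator on initialize_lmdeploy_pipeline means repeated calls with the same arguments reuse one already-built pipeline instead of reloading the engine. A toy illustration of that caching behaviour (not lmdeploy-specific, added here only for clarity):

from functools import lru_cache

@lru_cache(maxsize=1)
def build_engine(model):
    # The body runs only once per distinct argument tuple; later calls hit the cache.
    print(f"loading {model} ...")
    return object()

a = build_engine("some-model")
b = build_engine("some-model")
print(a is b)  # True: the cached engine object is returned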
@@ -354,36 +375,37 @@ async def lmdeploy_model_if_cache(
             and so on.
         chat_template (str): needed when model is a pytorch model on
             huggingface.co, such as "internlm-chat-7b",
             "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
             and when the model name of local path did not match the original model name in HF.
         tp (int): tensor parallel
         prompt (Union[str, List[str]]): input texts to be completed.
         do_preprocess (bool): whether pre-process the messages. Default to
             True, which means chat_template will be applied.
         skip_special_tokens (bool): Whether or not to remove special tokens
-            in the decoding. Default to be False.
+            in the decoding. Default to be True.
         do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
             Default to be False, which means greedy decoding will be applied.
     """
     try:
         import lmdeploy
         from lmdeploy import version_info, GenerationConfig
-    except:
+    except Exception:
         raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
 
     kwargs.pop("response_format", None)
     max_new_tokens = kwargs.pop("max_tokens", 512)
-    tp = kwargs.pop('tp', 1)
-    skip_special_tokens = kwargs.pop('skip_special_tokens', False)
-    do_preprocess = kwargs.pop('do_preprocess', True)
-    do_sample = kwargs.pop('do_sample', False)
+    tp = kwargs.pop("tp", 1)
+    skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+    do_preprocess = kwargs.pop("do_preprocess", True)
+    do_sample = kwargs.pop("do_sample", False)
     gen_params = kwargs
 
     version = version_info
     if do_sample is not None and version < (0, 6, 0):
         raise RuntimeError(
-            '`do_sample` parameter is not supported by lmdeploy until '
-            f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}')
+            "`do_sample` parameter is not supported by lmdeploy until "
+            f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
+        )
     else:
         do_sample = True
     gen_params.update(do_sample=do_sample)
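
Beyond the quote-style cleanup, this hunk changes the default of skip_special_tokens from False to True (matching the docstring fix above) and keeps the rule that do_sample needs lmdeploy >= 0.6.0. The surrounding code pops the options it handles itself off kwargs and forwards whatever remains as generation parameters; a small sketch of that pattern (split_generation_kwargs is a hypothetical helper, not code from the commit):

def split_generation_kwargs(**kwargs):
    # Pop the options this wrapper handles itself, with their defaults...
    max_new_tokens = kwargs.pop("max_tokens", 512)
    skip_special_tokens = kwargs.pop("skip_special_tokens", True)  # new default
    do_preprocess = kwargs.pop("do_preprocess", True)
    # ...and pass everything else straight through to GenerationConfig.
    gen_params = kwargs
    return max_new_tokens, skip_special_tokens, do_preprocess, gen_params

print(split_generation_kwargs(temperature=0.8, top_p=0.9))
# (512, True, True, {'temperature': 0.8, 'top_p': 0.9})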
@@ -394,7 +416,8 @@ async def lmdeploy_model_if_cache(
         chat_template=chat_template,
         model_format=model_format,
         quant_policy=quant_policy,
-        log_level='WARNING')
+        log_level="WARNING",
+    )
 
     messages = []
     if system_prompt:
@@ -410,11 +433,19 @@ async def lmdeploy_model_if_cache(
             return if_cache_return["return"]
 
     gen_config = GenerationConfig(
-        skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params)
+        skip_special_tokens=skip_special_tokens,
+        max_new_tokens=max_new_tokens,
+        **gen_params,
+    )
 
     response = ""
-    async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config,
-                    do_preprocess=do_preprocess, stream_response=False, session_id=1):
+    async for res in lmdeploy_pipe.generate(
+        messages,
+        gen_config=gen_config,
+        do_preprocess=do_preprocess,
+        stream_response=False,
+        session_id=1,
+    ):
         response += res.response
 
     if hashing_kv is not None:
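
For context, a hedged sketch of how the reformatted coroutine might be called after this commit. The import path, model path, chat_template name, and option values are placeholders (assuming the function lives in lightrag.llm as in the LightRAG repo), and lmdeploy must be installed for this to run:

import asyncio

from lightrag.llm import lmdeploy_model_if_cache  # assumed import path

async def main():
    # Placeholder model and options; adjust to a real checkpoint and template.
    text = await lmdeploy_model_if_cache(
        "internlm/internlm2-chat-7b",
        "What does this repository do?",
        system_prompt="You are a helpful assistant.",
        chat_template="internlm2-chat-7b",
        do_preprocess=True,          # apply the chat template (default)
        skip_special_tokens=True,    # now the default after this commit
        max_tokens=256,
    )
    print(text)

asyncio.run(main())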