From 8deb30aa205de458b727a24345f8b4a511dabafc Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 16:11:15 +0800 Subject: [PATCH 1/3] support lmdeploy backend --- examples/lightrag_lmdeploy_demo.py | 74 +++++++++++++++++++++ lightrag/llm.py | 100 +++++++++++++++++++++++++++++ requirements.txt | 1 + 3 files changed, 175 insertions(+) create mode 100644 examples/lightrag_lmdeploy_demo.py diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py new file mode 100644 index 00000000..ea7ace0e --- /dev/null +++ b/examples/lightrag_lmdeploy_demo.py @@ -0,0 +1,74 @@ +import os + +from lightrag import LightRAG, QueryParam +from lightrag.llm import lmdeploy_model_if_cache, hf_embedding +from lightrag.utils import EmbeddingFunc +from transformers import AutoModel, AutoTokenizer + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +async def lmdeploy_model_complete( + prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] + return await lmdeploy_model_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + ## please specify chat_template if your local path does not follow original HF file name, + ## or model_name is a pytorch model on huggingface.co, + ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py + ## for a list of chat_template available in lmdeploy. + chat_template = "llama3", + # model_format ='awq', # if you are using awq quantization model. + # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8. + **kwargs, + ) + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=lmdeploy_model_complete, + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=5000, + func=lambda texts: hf_embedding( + texts, + tokenizer=AutoTokenizer.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + embed_model=AutoModel.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + ), + ), +) + + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/lightrag/llm.py b/lightrag/llm.py index bb0d6063..028084bd 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -322,6 +322,106 @@ async def ollama_model_if_cache( return result +@lru_cache(maxsize=1) +def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0): + from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig + lmdeploy_pipe = pipeline( + model_path=model, + backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy), + chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None, + log_level='WARNING') + return lmdeploy_pipe + + +async def lmdeploy_model_if_cache( + model, prompt, system_prompt=None, history_messages=[], + chat_template=None, model_format='hf',quant_policy=0, **kwargs +) -> str: + """ + Args: + model (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download + from ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + chat_template (str): needed when model is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, + and when the model name of local path did not match the original model name in HF. + tp (int): tensor parallel + prompt (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be False. + do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. + Default to be False, which means greedy decoding will be applied. + """ + try: + import lmdeploy + from lmdeploy import version_info, GenerationConfig + except: + raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") + + kwargs.pop("response_format", None) + max_new_tokens = kwargs.pop("max_tokens", 512) + tp = kwargs.pop('tp', 1) + skip_special_tokens = kwargs.pop('skip_special_tokens', False) + do_preprocess = kwargs.pop('do_preprocess', True) + do_sample = kwargs.pop('do_sample', False) + gen_params = kwargs + + version = version_info + if do_sample is not None and version < (0, 6, 0): + raise RuntimeError( + '`do_sample` parameter is not supported by lmdeploy until ' + f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}') + else: + do_sample = True + gen_params.update(do_sample=do_sample) + + lmdeploy_pipe = initialize_lmdeploy_pipeline( + model=model, + tp=tp, + chat_template=chat_template, + model_format=model_format, + quant_policy=quant_policy, + log_level='WARNING') + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + gen_config = GenerationConfig( + skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params) + + response = "" + async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config, + do_preprocess=do_preprocess, stream_response=False, session_id=1): + response += res.response + + if hashing_kv is not None: + await hashing_kv.upsert({args_hash: {"return": response, "model": model}}) + return response + + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: diff --git a/requirements.txt b/requirements.txt index 98f32b0a..6b0e025a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ tiktoken torch transformers xxhash +# lmdeploy[all] From 2e703296d5e9f4a15547c1d1be3ecb53eab1925c Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 16:13:18 +0800 Subject: [PATCH 2/3] pre-commit --- examples/lightrag_lmdeploy_demo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py index ea7ace0e..aeb96f71 100644 --- a/examples/lightrag_lmdeploy_demo.py +++ b/examples/lightrag_lmdeploy_demo.py @@ -10,10 +10,11 @@ WORKING_DIR = "./dickens" if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def lmdeploy_model_complete( prompt=None, system_prompt=None, history_messages=[], **kwargs ) -> str: - model_name = kwargs["hashing_kv"].global_config["llm_model_name"] + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await lmdeploy_model_if_cache( model_name, prompt, @@ -23,7 +24,7 @@ async def lmdeploy_model_complete( ## or model_name is a pytorch model on huggingface.co, ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py ## for a list of chat_template available in lmdeploy. - chat_template = "llama3", + chat_template="llama3", # model_format ='awq', # if you are using awq quantization model. # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8. **kwargs, @@ -33,7 +34,7 @@ async def lmdeploy_model_complete( rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=lmdeploy_model_complete, - llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model embedding_func=EmbeddingFunc( embedding_dim=384, max_token_size=5000, From 2cf3a85a0f09094372ae632cfce95cf1f649de76 Mon Sep 17 00:00:00 2001 From: tackhwa Date: Sat, 26 Oct 2024 16:24:35 +0800 Subject: [PATCH 3/3] update do_preprocess --- lightrag/llm.py | 77 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 23 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 028084bd..d86886ea 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -286,7 +286,9 @@ async def hf_model_if_cache( output = hf_model.generate( **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True ) - response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True) + response_text = hf_tokenizer.decode( + output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True + ) if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) return response_text @@ -323,19 +325,38 @@ async def ollama_model_if_cache( @lru_cache(maxsize=1) -def initialize_lmdeploy_pipeline(model, tp=1, chat_template=None, log_level='WARNING', model_format='hf', quant_policy=0): +def initialize_lmdeploy_pipeline( + model, + tp=1, + chat_template=None, + log_level="WARNING", + model_format="hf", + quant_policy=0, +): from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig + lmdeploy_pipe = pipeline( model_path=model, - backend_config=TurbomindEngineConfig(tp=tp, model_format=model_format, quant_policy=quant_policy), - chat_template_config=ChatTemplateConfig(model_name=chat_template) if chat_template else None, - log_level='WARNING') + backend_config=TurbomindEngineConfig( + tp=tp, model_format=model_format, quant_policy=quant_policy + ), + chat_template_config=ChatTemplateConfig(model_name=chat_template) + if chat_template + else None, + log_level="WARNING", + ) return lmdeploy_pipe async def lmdeploy_model_if_cache( - model, prompt, system_prompt=None, history_messages=[], - chat_template=None, model_format='hf',quant_policy=0, **kwargs + model, + prompt, + system_prompt=None, + history_messages=[], + chat_template=None, + model_format="hf", + quant_policy=0, + **kwargs, ) -> str: """ Args: @@ -354,36 +375,37 @@ async def lmdeploy_model_if_cache( and so on. chat_template (str): needed when model is a pytorch model on huggingface.co, such as "internlm-chat-7b", - "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, and when the model name of local path did not match the original model name in HF. tp (int): tensor parallel prompt (Union[str, List[str]]): input texts to be completed. do_preprocess (bool): whether pre-process the messages. Default to True, which means chat_template will be applied. skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be False. - do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. + in the decoding. Default to be True. + do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. Default to be False, which means greedy decoding will be applied. """ try: import lmdeploy from lmdeploy import version_info, GenerationConfig - except: + except Exception: raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") - + kwargs.pop("response_format", None) max_new_tokens = kwargs.pop("max_tokens", 512) - tp = kwargs.pop('tp', 1) - skip_special_tokens = kwargs.pop('skip_special_tokens', False) - do_preprocess = kwargs.pop('do_preprocess', True) - do_sample = kwargs.pop('do_sample', False) + tp = kwargs.pop("tp", 1) + skip_special_tokens = kwargs.pop("skip_special_tokens", True) + do_preprocess = kwargs.pop("do_preprocess", True) + do_sample = kwargs.pop("do_sample", False) gen_params = kwargs - + version = version_info if do_sample is not None and version < (0, 6, 0): raise RuntimeError( - '`do_sample` parameter is not supported by lmdeploy until ' - f'v0.6.0, but currently using lmdeloy {lmdeploy.__version__}') + "`do_sample` parameter is not supported by lmdeploy until " + f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}" + ) else: do_sample = True gen_params.update(do_sample=do_sample) @@ -394,7 +416,8 @@ async def lmdeploy_model_if_cache( chat_template=chat_template, model_format=model_format, quant_policy=quant_policy, - log_level='WARNING') + log_level="WARNING", + ) messages = [] if system_prompt: @@ -410,11 +433,19 @@ async def lmdeploy_model_if_cache( return if_cache_return["return"] gen_config = GenerationConfig( - skip_special_tokens=skip_special_tokens, max_new_tokens=max_new_tokens, **gen_params) + skip_special_tokens=skip_special_tokens, + max_new_tokens=max_new_tokens, + **gen_params, + ) response = "" - async for res in lmdeploy_pipe.generate(messages, gen_config=gen_config, - do_preprocess=do_preprocess, stream_response=False, session_id=1): + async for res in lmdeploy_pipe.generate( + messages, + gen_config=gen_config, + do_preprocess=do_preprocess, + stream_response=False, + session_id=1, + ): response += res.response if hashing_kv is not None: