diff --git a/lightrag/llm.py b/lightrag/llm.py
index 76adec26..4dcf535c 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -207,6 +207,8 @@ async def bedrock_complete_if_cache(
 def initialize_hf_model(model_name):
     hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
     hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+    if hf_tokenizer.pad_token is None:
+        hf_tokenizer.pad_token = hf_tokenizer.eos_token
     return hf_model, hf_tokenizer
 
 
@@ -216,9 +218,6 @@ async def hf_model_if_cache(
 ) -> str:
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
-    if hf_tokenizer.pad_token is None:
-        # print("use eos token")
-        hf_tokenizer.pad_token = hf_tokenizer.eos_token
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
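
The change moves the eos-token fallback out of the per-request `hf_model_if_cache` path and into `initialize_hf_model`, so the tokenizer is patched once at load time instead of being re-checked on every completion call. Below is a minimal standalone sketch (not lightrag code; the `gpt2` model name is only an illustration) of why the fallback matters: many causal LM tokenizers ship without a pad token, and padded batch encoding raises a `ValueError` until one is set.

```python
# Standalone sketch, assuming the Hugging Face transformers API.
# "gpt2" is an illustrative model that ships without a pad token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
assert tokenizer.pad_token is None

# Without a pad token, padded batch encoding fails:
#   tokenizer(["short", "a longer prompt"], padding=True)
#   -> ValueError: Asking to pad but the tokenizer does not have a padding token.

# The same fallback the patch applies, done once at load time:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["short", "a longer prompt"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # e.g. torch.Size([2, 4])
```

Reusing `eos_token` as the pad token is the common convention for decoder-only models, since padding positions are masked out by the attention mask and the eos id never conflicts with real vocabulary entries.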