From dfec83de1db29d485383b7397a9d3077863648f1 Mon Sep 17 00:00:00 2001
From: tackhwa
Date: Wed, 23 Oct 2024 15:02:28 +0800
Subject: [PATCH] Fix HF bug: cache model initialization and correct SiliconCloud API key env var

---
 examples/lightrag_siliconcloud_demo.py |  4 ++--
 lightrag/llm.py                        | 12 ++++++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
index 8be6ae7a..82cab228 100644
--- a/examples/lightrag_siliconcloud_demo.py
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -19,7 +19,7 @@ async def llm_model_func(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
-        api_key=os.getenv("UPSTAGE_API_KEY"),
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
         base_url="https://api.siliconflow.cn/v1/",
         **kwargs,
     )
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
     return await siliconcloud_embedding(
         texts,
         model="netease-youdao/bce-embedding-base_v1",
-        api_key=os.getenv("UPSTAGE_API_KEY"),
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
         max_token_size=512
     )

diff --git a/lightrag/llm.py b/lightrag/llm.py
index 67f547ea..76adec26 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,5 +1,6 @@
 import os
 import copy
+from functools import lru_cache
 import json
 import aioboto3
 import aiohttp
@@ -202,15 +203,22 @@ async def bedrock_complete_if_cache(
     return response["output"]["message"]["content"][0]["text"]


+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+
+    return hf_model, hf_tokenizer
+
+
 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     model_name = model
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    hf_model, hf_tokenizer = initialize_hf_model(model_name)
     if hf_tokenizer.pad_token is None:
         # print("use eos token")
         hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
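
Note for reviewers: the core of this patch is wrapping the Hugging Face model load in a
`functools.lru_cache`d helper, so `hf_model_if_cache` reuses the already-loaded tokenizer
and model instead of re-instantiating them on every call. Below is a minimal, self-contained
sketch of that pattern under stated assumptions: the model name "gpt2" is illustrative (any
Hugging Face causal LM works), and unlike the patch it omits `device_map` on the tokenizer,
since that argument is only meaningful for the model.

    # Minimal sketch of the caching pattern this patch introduces.
    # Assumes `transformers` is installed; `device_map="auto"` additionally
    # requires the `accelerate` package.
    from functools import lru_cache

    from transformers import AutoModelForCausalLM, AutoTokenizer


    @lru_cache(maxsize=1)
    def initialize_hf_model(model_name):
        # from_pretrained is expensive (downloads and loads weights);
        # lru_cache memoizes the result, so a given model_name is only
        # loaded once per process.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        return model, tokenizer


    model_a, _ = initialize_hf_model("gpt2")
    model_b, _ = initialize_hf_model("gpt2")
    assert model_a is model_b  # second call hits the cache; no reload

One design consequence of `maxsize=1`: requesting a different `model_name` evicts the
previously cached pair, which suits single-model usage; a larger `maxsize` would keep
several models resident at the cost of GPU/CPU memory.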