Merge pull request #110 from tackhwa/main
[FIX] Fix bug where the HF model was reloaded on every call, causing OOM
This commit is contained in:
@@ -19,7 +19,7 @@ async def llm_model_func(
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
base_url="https://api.siliconflow.cn/v1/",
|
||||
**kwargs,
|
||||
)
|
||||
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="netease-youdao/bce-embedding-base_v1",
|
||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
max_token_size=512
|
||||
)
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import aioboto3
|
||||
import aiohttp
|
||||
@@ -202,15 +203,21 @@ async def bedrock_complete_if_cache(
|
||||
return response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
    """Load the tokenizer and causal language model for ``model_name``.

    The result is memoized with ``lru_cache(maxsize=1)``, so repeated calls
    with the same ``model_name`` return the already-loaded objects instead
    of loading them again; only the most recent model name is kept.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Both loaders receive the same options.
    # NOTE(review): trust_remote_code=True presumably lets the model repo
    # run its own code at load time — confirm this is intended.
    load_options = {"device_map": "auto", "trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(model_name, **load_options)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
|
||||
|
||||
|
||||
async def hf_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = model
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
|
||||
if hf_tokenizer.pad_token is None:
|
||||
# print("use eos token")
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
||||
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
|
Reference in New Issue
Block a user