fix hf bug

tackhwa
2024-10-23 15:02:28 +08:00
parent 5972958e79
commit 63c0283514
2 changed files with 12 additions and 4 deletions

Changed file 1 of 2:

@@ -19,7 +19,7 @@ async def llm_model_func(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
-        api_key=os.getenv("UPSTAGE_API_KEY"),
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
         base_url="https://api.siliconflow.cn/v1/",
         **kwargs,
     )
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
     return await siliconcloud_embedding(
         texts,
         model="netease-youdao/bce-embedding-base_v1",
-        api_key=os.getenv("UPSTAGE_API_KEY"),
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
         max_token_size=512
     )
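
Both calls in this demo previously read the key from UPSTAGE_API_KEY, a leftover from the Upstage example; they now consistently read SILICONFLOW_API_KEY, matching the SiliconFlow base_url. A minimal sketch of a pre-flight guard one could run before the demo (the variable name comes from the diff; the check itself is illustrative and not part of this commit):

import os

# Illustrative check (not part of this commit): llm_model_func and
# embedding_func both read the key via os.getenv("SILICONFLOW_API_KEY"),
# so it must be exported before the demo is run.
if os.getenv("SILICONFLOW_API_KEY") is None:
    raise RuntimeError(
        "SILICONFLOW_API_KEY is not set; export it first, e.g. "
        'export SILICONFLOW_API_KEY="sk-..."'
    )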

Changed file 2 of 2:

@@ -1,5 +1,6 @@
 import os
 import copy
+from functools import lru_cache
 import json
 import aioboto3
 import aiohttp
@@ -202,15 +203,22 @@ async def bedrock_complete_if_cache(
     return response["output"]["message"]["content"][0]["text"]
 
 
+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+
+    return hf_model, hf_tokenizer
+
+
 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     model_name = model
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    hf_model, hf_tokenizer = initialize_hf_model(model_name)
     if hf_tokenizer.pad_token is None:
         # print("use eos token")
         hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
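
The HF bug being fixed: hf_model_if_cache called AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained on every invocation, reloading the weights from disk each call. Wrapping the loading in initialize_hf_model with @lru_cache(maxsize=1) loads the model and tokenizer once and reuses them on subsequent calls; with maxsize=1, requesting a different model_name evicts the previously cached pair. A minimal sketch of that caching pattern, with a cheap stand-in loader instead of the real transformers calls:

from functools import lru_cache

@lru_cache(maxsize=1)
def load_model(model_name: str):
    # Stand-in for the expensive AutoModelForCausalLM/AutoTokenizer loading;
    # this body runs only on a cache miss.
    print(f"loading {model_name} ...")
    return {"name": model_name}  # placeholder for (hf_model, hf_tokenizer)

first = load_model("gpt2")    # prints "loading gpt2 ..."
second = load_model("gpt2")   # cache hit: same object returned, no reload
assert first is second
load_model("distilgpt2")      # new name: reloads and evicts "gpt2" (maxsize=1)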