fix hf bug

Two fixes: the example script read the API key from UPSTAGE_API_KEY even
though base_url already points at api.siliconflow.cn; it now uses
SILICONFLOW_API_KEY. Separately, hf_model_if_cache re-loaded the
HuggingFace model and tokenizer from disk on every call; loading is now
done once by an @lru_cache-wrapped initialize_hf_model helper and reused.
@@ -19,7 +19,7 @@ async def llm_model_func(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
-        api_key=os.getenv("UPSTAGE_API_KEY"),
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
        base_url="https://api.siliconflow.cn/v1/",
         **kwargs,
     )
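With the correct key in place, the patched wrapper can be smoke-tested on its own. A minimal sketch, assuming the example defines llm_model_func with the usual (prompt, system_prompt=None, history_messages=[], **kwargs) signature and that SILICONFLOW_API_KEY is exported:

import asyncio

async def main():
    # Assumes llm_model_func comes from the example script this hunk
    # patches and that a valid key is present in the environment.
    reply = await llm_model_func("Say hello in one short sentence.")
    print(reply)

asyncio.run(main())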
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
     return await siliconcloud_embedding(
         texts,
         model="netease-youdao/bce-embedding-base_v1",
-        api_key=os.getenv("UPSTAGE_API_KEY"),
+        api_key=os.getenv("SILICONFLOW_API_KEY"),
         max_token_size=512
     )

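Both call sites now read the key from the environment, so nothing fails until the first request goes out with api_key=None. A pre-flight guard one could add at the top of the demo (hypothetical, not part of this commit):

import os

# Hypothetical guard, not in this commit: fail fast with a clear message
# instead of sending a request with api_key=None.
if not os.getenv("SILICONFLOW_API_KEY"):
    raise SystemExit(
        "SILICONFLOW_API_KEY is not set; run e.g.\n"
        "  export SILICONFLOW_API_KEY=sk-..."
    )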
@@ -1,5 +1,6 @@
 import os
 import copy
+from functools import lru_cache
 import json
 import aioboto3
 import aiohttp
@@ -202,15 +203,22 @@ async def bedrock_complete_if_cache(
     return response["output"]["message"]["content"][0]["text"]


+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+
+    return hf_model, hf_tokenizer
+
+
 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     model_name = model
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    hf_model, hf_tokenizer = initialize_hf_model(model_name)
     if hf_tokenizer.pad_token is None:
         # print("use eos token")
         hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
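The core of the fix is that functools.lru_cache memoizes on the argument, so the expensive from_pretrained calls run once per model name; with maxsize=1 only the most recently requested model stays resident. A standalone sketch of the pattern, with a dummy loader standing in for the transformers calls:

from functools import lru_cache

@lru_cache(maxsize=1)
def initialize_model(name: str):
    # Stand-in for the AutoTokenizer/AutoModelForCausalLM loading above.
    print(f"loading {name} ...")
    return object(), object()

m1, t1 = initialize_model("model-a")  # prints: loading model-a ...
m2, t2 = initialize_model("model-a")  # cache hit: no load, same objects
assert m1 is m2 and t1 is t2
initialize_model("model-b")           # evicts model-a (maxsize=1)
initialize_model("model-a")           # loads again

One consequence of maxsize=1: a workload that alternates between two model names would reload on every switch; raising maxsize trades memory for that case.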