Update llm.py

This commit is contained in:
Abyl Ikhsanov
2024-10-21 20:40:49 +02:00
committed by GitHub
parent 6527e75e81
commit 1dd927eb9d

View File

@@ -4,7 +4,7 @@ import json
import aioboto3 import aioboto3
import numpy as np import numpy as np
import ollama import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
from tenacity import ( from tenacity import (
retry, retry,
stop_after_attempt, stop_after_attempt,
@@ -61,6 +61,49 @@ async def openai_complete_if_cache(
) )
return response.choices[0].message.content return response.choices[0].message.content
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    base_url=None,
    api_key=None,
    **kwargs,
):
    """Query an Azure OpenAI chat deployment, serving from a KV cache when possible.

    Args:
        model: Azure OpenAI deployment name to call.
        prompt: User message appended last; skipped when ``None``.
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior chat messages (``{"role", "content"}`` dicts).
        base_url: If given, overrides the ``AZURE_OPENAI_ENDPOINT`` env var.
        api_key: If given, overrides the ``AZURE_OPENAI_API_KEY`` env var.
        **kwargs: Extra options forwarded to ``chat.completions.create``.
            ``"hashing_kv"`` (a BaseKVStorage) is popped out and used as the
            response cache when present.

    Returns:
        str: The assistant message content.
    """
    # Avoid the shared-mutable-default pitfall of `history_messages=[]`:
    # a list default would be shared across all calls.
    if history_messages is None:
        history_messages = []
    # Credentials are exported to the environment so this client (and any
    # later ones in the same process) pick them up.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        # Cache key covers the deployment name and the full message list.
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    return response.choices[0].message.content
class BedrockError(Exception): class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock""" """Generic error for issues related to Amazon Bedrock"""
@@ -270,6 +313,16 @@ async def gpt_4o_mini_complete(
**kwargs, **kwargs,
) )
async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete *prompt* against the ``conversation-4o-mini`` Azure deployment.

    Thin convenience wrapper around :func:`azure_openai_complete_if_cache`
    with the deployment name pinned.

    Args:
        prompt: User message to complete.
        system_prompt: Optional system message.
        history_messages: Prior chat messages; defaults to an empty history.
        **kwargs: Forwarded unchanged (including the optional ``hashing_kv``).

    Returns:
        str: The assistant message content.
    """
    # `None` sentinel instead of a mutable `[]` default shared across calls.
    if history_messages is None:
        history_messages = []
    return await azure_openai_complete_if_cache(
        "conversation-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
async def bedrock_complete( async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
@@ -332,6 +385,32 @@ async def openai_embedding(
) )
return np.array([dp.embedding for dp in response.data]) return np.array([dp.embedding for dp in response.data])
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_embedding(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    """Embed *texts* using an Azure OpenAI embedding deployment.

    Args:
        texts: Strings to embed.
        model: Azure embedding deployment name.
        base_url: If given, overrides the ``AZURE_OPENAI_ENDPOINT`` env var.
        api_key: If given, overrides the ``AZURE_OPENAI_API_KEY`` env var.

    Returns:
        np.ndarray: One embedding row per input text.
    """
    # Explicit credentials win over whatever is already in the environment.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url

    client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )
    response = await client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    vectors = [item.embedding for item in response.data]
    return np.array(vectors)
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry( # @retry(