From 1dd927eb9d5fd8c8af0391a0c3de3e0e2fc1e2b0 Mon Sep 17 00:00:00 2001
From: Abyl Ikhsanov
Date: Mon, 21 Oct 2024 20:40:49 +0200
Subject: [PATCH] Update llm.py

---
 lightrag/llm.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 80 insertions(+), 1 deletion(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index be801e0c..51c48b84 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -4,7 +4,7 @@ import json
 import aioboto3
 import numpy as np
 import ollama
-from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
+from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -61,6 +61,49 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_complete_if_cache(model,
+                                         prompt,
+                                         system_prompt=None,
+                                         history_messages=[],
+                                         base_url=None,
+                                         api_key=None,
+                                         **kwargs):
+    if api_key:
+        os.environ["AZURE_OPENAI_API_KEY"] = api_key
+    if base_url:
+        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+    openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+                                           api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+                                           api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    if prompt is not None:
+        messages.append({"role": "user", "content": prompt})
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+
+    response = await openai_async_client.chat.completions.create(
+        model=model, messages=messages, **kwargs
+    )
+
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+        )
+    return response.choices[0].message.content
 
 class BedrockError(Exception):
     """Generic error for issues related to Amazon Bedrock"""
@@ -270,6 +313,16 @@ async def gpt_4o_mini_complete(
         **kwargs,
     )
 
+async def azure_openai_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await azure_openai_complete_if_cache(
+        "conversation-4o-mini",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
 
 async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
@@ -332,6 +385,32 @@ async def openai_embedding(
     )
     return np.array([dp.embedding for dp in response.data])
 
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_embedding(
+    texts: list[str],
+    model: str = "text-embedding-3-small",
+    base_url: str = None,
+    api_key: str = None,
+) -> np.ndarray:
+    if api_key:
+        os.environ["AZURE_OPENAI_API_KEY"] = api_key
+    if base_url:
+        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+    openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+                                           api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+                                           api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
+
+    response = await openai_async_client.embeddings.create(
+        model=model, input=texts, encoding_format="float"
+    )
+    return np.array([dp.embedding for dp in response.data])
+
 # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 # @retry(
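
A minimal usage sketch for the new entry points, not part of the patch itself. It assumes the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_API_VERSION environment variables are set, and that the hardcoded deployment names above ("conversation-4o-mini" for chat, "text-embedding-3-small" for embeddings) exist in the target Azure OpenAI resource:

    import asyncio
    from lightrag.llm import azure_openai_complete, azure_openai_embedding

    async def main():
        # Chat completion routed through the Azure deployment hardcoded in
        # azure_openai_complete (falls back to env vars for credentials).
        answer = await azure_openai_complete("Summarize LightRAG in one sentence.")
        print(answer)

        # Embeddings come back as a NumPy array of shape (n_texts, 1536).
        vectors = await azure_openai_embedding(["first text", "second text"])
        print(vectors.shape)

    asyncio.run(main())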