@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    base_url=None,
    api_key=None,
    **kwargs,
):
    """Run an Azure OpenAI chat completion, consulting an optional cache first.

    The message list is assembled as: optional system message, then
    ``history_messages``, then the user ``prompt`` (skipped when ``None``).

    Args:
        model: Value forwarded as ``model`` to ``chat.completions.create``.
        prompt: User message content; omitted from the request when ``None``.
        system_prompt: Optional system message content; skipped when falsy.
        history_messages: Optional iterable of prior message dicts that is
            inserted between the system and user messages. ``None`` (the
            default) means no history.
        base_url: When truthy, written to ``AZURE_OPENAI_ENDPOINT`` in
            ``os.environ`` before the client is built.
        api_key: When truthy, written to ``AZURE_OPENAI_API_KEY`` in
            ``os.environ`` before the client is built.
        **kwargs: ``hashing_kv`` (a cache store exposing ``get_by_id`` and
            ``upsert``) is popped and used for caching; everything else is
            forwarded to ``chat.completions.create``.

    Returns:
        The message content of the first choice in the response, or the
        cached ``"return"`` value when the cache already holds this request.

    The ``retry`` decorator re-invokes the call on ``RateLimitError``,
    ``APIConnectionError`` or ``Timeout`` and stops after three attempts.
    """
    # NOTE(review): explicit credentials are published process-wide through
    # os.environ and stay set after this call returns. Kept as-is because
    # later calls may rely on it; confirm whether that is intended.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    # Look in the cache before creating a client: a hit needs neither
    # credentials nor a network connection.
    args_hash = None
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        cached = await hashing_kv.get_by_id(args_hash)
        if cached is not None:
            return cached["return"]

    client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )
    response = await client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": content, "model": model}})
    return content
async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Chat-complete ``prompt`` through ``azure_openai_complete_if_cache``.

    Args:
        prompt: User message content.
        system_prompt: Optional system message content.
        history_messages: Optional list of prior message dicts; ``None``
            (the default) is forwarded as an empty list.
        **kwargs: ``model`` may be supplied to override the default
            ``"conversation-4o-mini"``; all remaining keyword arguments are
            forwarded unchanged to ``azure_openai_complete_if_cache``.

    Returns:
        Whatever ``azure_openai_complete_if_cache`` returns for the request.
    """
    # Popping "model" keeps it out of **kwargs, where it would otherwise
    # collide with the positional model argument below.
    model = kwargs.pop("model", "conversation-4o-mini")
    return await azure_openai_complete_if_cache(
        model,
        prompt,
        system_prompt=system_prompt,
        history_messages=[] if history_messages is None else history_messages,
        **kwargs,
    )


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_embedding(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    """Embed ``texts`` with an Azure OpenAI embeddings request.

    Args:
        texts: Strings sent as the ``input`` of the embeddings request.
        model: Value forwarded as ``model`` to ``embeddings.create``.
        base_url: When truthy, written to ``AZURE_OPENAI_ENDPOINT`` in
            ``os.environ`` before the client is built.
        api_key: When truthy, written to ``AZURE_OPENAI_API_KEY`` in
            ``os.environ`` before the client is built.

    Returns:
        A NumPy array built from the ``embedding`` field of every item in
        the response's ``data``, in response order.

    The ``retry`` decorator re-invokes the call on ``RateLimitError``,
    ``APIConnectionError`` or ``Timeout`` and stops after three attempts.
    """
    # NOTE(review): explicit credentials are published process-wide through
    # os.environ and stay set after this call returns; confirm this is intended.
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url

    client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )
    response = await client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    vectors = [item.embedding for item in response.data]
    return np.array(vectors)