diff --git a/lightrag/llm.py b/lightrag/llm.py
index 12a4d5a6..7dc8b886 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -101,12 +101,15 @@ async def azure_openai_complete_if_cache(
     history_messages=[],
     base_url=None,
     api_key=None,
+    api_version=None,
     **kwargs,
 ):
     if api_key:
         os.environ["AZURE_OPENAI_API_KEY"] = api_key
     if base_url:
         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+    if api_version:
+        os.environ["AZURE_OPENAI_API_VERSION"] = api_version
 
     openai_async_client = AsyncAzureOpenAI(
         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
@@ -585,7 +588,7 @@ async def openai_embedding(
     return np.array([dp.embedding for dp in response.data])
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -596,11 +599,14 @@ async def azure_openai_embedding(
     model: str = "text-embedding-3-small",
     base_url: str = None,
     api_key: str = None,
+    api_version: str = None,
 ) -> np.ndarray:
     if api_key:
         os.environ["AZURE_OPENAI_API_KEY"] = api_key
     if base_url:
         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+    if api_version:
+        os.environ["AZURE_OPENAI_API_VERSION"] = api_version
 
     openai_async_client = AsyncAzureOpenAI(
         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
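
A minimal usage sketch of the new api_version keyword added by this diff. The leading positional parameters of azure_openai_complete_if_cache and azure_openai_embedding (deployment/model name, prompt, and the list of texts) are not shown in the hunks above and are assumed here; the endpoint, key, and api_version values are placeholders, not part of this change.

import asyncio

from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding


async def main():
    # Completion call: api_version is the keyword newly threaded through to
    # the AZURE_OPENAI_API_VERSION environment variable by this diff.
    answer = await azure_openai_complete_if_cache(
        "gpt-4o-mini",                                     # assumed: Azure deployment name
        "What is LightRAG?",                               # assumed: user prompt
        base_url="https://my-resource.openai.azure.com",   # placeholder endpoint
        api_key="...",                                     # placeholder key
        api_version="2024-08-01-preview",                  # newly supported keyword
    )
    print(answer)

    # Embedding call: the first argument (list of input texts) is assumed;
    # api_version is handled the same way as in the completion path.
    vectors = await azure_openai_embedding(
        ["hello world"],
        model="text-embedding-3-small",
        base_url="https://my-resource.openai.azure.com",
        api_key="...",
        api_version="2024-08-01-preview",
    )
    print(vectors.shape)


asyncio.run(main())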