Merge pull request #360 from ahmadhatahet/azure_openai_embedding

Azure OpenAI Embedding
This commit is contained in:
zrguo
2024-12-02 16:02:13 +08:00
committed by GitHub

View File

@@ -101,12 +101,15 @@ async def azure_openai_complete_if_cache(
history_messages=[],
base_url=None,
api_key=None,
api_version=None,
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
@@ -585,7 +588,7 @@ async def openai_embedding(
return np.array([dp.embedding for dp in response.data])
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -596,11 +599,14 @@ async def azure_openai_embedding(
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
api_version: str = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),