Merge pull request #360 from ahmadhatahet/azure_openai_embedding

Azure OpenAI Embedding
This commit is contained in:
zrguo
2024-12-02 16:02:13 +08:00
committed by GitHub

View File

@@ -101,12 +101,15 @@ async def azure_openai_complete_if_cache(
history_messages=[], history_messages=[],
base_url=None, base_url=None,
api_key=None, api_key=None,
api_version=None,
**kwargs, **kwargs,
): ):
if api_key: if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url: if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI( openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
@@ -585,7 +588,7 @@ async def openai_embedding(
return np.array([dp.embedding for dp in response.data]) return np.array([dp.embedding for dp in response.data])
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -596,11 +599,14 @@ async def azure_openai_embedding(
model: str = "text-embedding-3-small", model: str = "text-embedding-3-small",
base_url: str = None, base_url: str = None,
api_key: str = None, api_key: str = None,
api_version: str = None,
) -> np.ndarray: ) -> np.ndarray:
if api_key: if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url: if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI( openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),