Merge pull request #360 from ahmadhatahet/azure_openai_embedding
Azure OpenAI Embedding
This commit is contained in:
@@ -101,12 +101,15 @@ async def azure_openai_complete_if_cache(
|
|||||||
history_messages=[],
|
history_messages=[],
|
||||||
base_url=None,
|
base_url=None,
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
api_version=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if api_key:
|
if api_key:
|
||||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||||
if base_url:
|
if base_url:
|
||||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||||
|
if api_version:
|
||||||
|
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
|
||||||
|
|
||||||
openai_async_client = AsyncAzureOpenAI(
|
openai_async_client = AsyncAzureOpenAI(
|
||||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||||
@@ -585,7 +588,7 @@ async def openai_embedding(
|
|||||||
return np.array([dp.embedding for dp in response.data])
|
return np.array([dp.embedding for dp in response.data])
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
@@ -596,11 +599,14 @@ async def azure_openai_embedding(
|
|||||||
model: str = "text-embedding-3-small",
|
model: str = "text-embedding-3-small",
|
||||||
base_url: str = None,
|
base_url: str = None,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
|
api_version: str = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
if api_key:
|
if api_key:
|
||||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||||
if base_url:
|
if base_url:
|
||||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||||
|
if api_version:
|
||||||
|
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
|
||||||
|
|
||||||
openai_async_client = AsyncAzureOpenAI(
|
openai_async_client = AsyncAzureOpenAI(
|
||||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||||
|
Reference in New Issue
Block a user