feat(examples): support siliconcloud free API
This commit is contained in:
@@ -2,8 +2,11 @@ import os
|
||||
import copy
|
||||
import json
|
||||
import aioboto3
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import ollama
|
||||
import base64
|
||||
import struct
|
||||
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
||||
from tenacity import (
|
||||
retry,
|
||||
@@ -312,7 +315,7 @@ async def ollama_model_complete(
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def openai_embedding(
|
||||
@@ -332,6 +335,49 @@ async def openai_embedding(
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def siliconcloud_embedding(
|
||||
texts: list[str],
|
||||
model: str = "netease-youdao/bce-embedding-base_v1",
|
||||
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
|
||||
max_token_size: int = 512,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key and not api_key.startswith('Bearer '):
|
||||
api_key = 'Bearer ' + api_key
|
||||
|
||||
headers = {
|
||||
"Authorization": api_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
truncate_texts = [text[0:max_token_size] for text in texts]
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"input": truncate_texts,
|
||||
"encoding_format": "base64"
|
||||
}
|
||||
|
||||
base64_strings = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(base_url, headers=headers, json=payload) as response:
|
||||
content = await response.json()
|
||||
if 'code' in content:
|
||||
raise ValueError(content)
|
||||
base64_strings = [item['embedding'] for item in content['data']]
|
||||
|
||||
embeddings = []
|
||||
for string in base64_strings:
|
||||
decode_bytes = base64.b64decode(string)
|
||||
n = len(decode_bytes) // 4
|
||||
float_array = struct.unpack('<' + 'f' * n, decode_bytes)
|
||||
embeddings.append(float_array)
|
||||
return np.array(embeddings)
|
||||
|
||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
# @retry(
|
||||
|
Reference in New Issue
Block a user