Fixed a bug introduced by someone else's modification to azure_openai_complete (please make sure you test before committing code)

Added api_key to the lollms, ollama, and openai bindings (both LLM and embedding), allowing API-key-protected services to be used.
Saifeddine ALOUI
2025-01-19 23:24:37 +01:00
parent 34c051b0b9
commit 9cae05e1ff
2 changed files with 22 additions and 8 deletions


@@ -299,7 +299,7 @@ def parse_args() -> argparse.Namespace:
     )
     default_llm_api_key = get_env_value(
-        "LLM_BINDING_API_KEY", ""
+        "LLM_BINDING_API_KEY", None
     )
     parser.add_argument(
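
A note on this default change: both "" and None are falsy, so the `if api_key` guards added in the hunks below behave identically either way, but None states "no key configured" unambiguously, and clients such as openai's fall back to their own environment variables only when given None, which an empty string would suppress. A minimal sketch, assuming get_env_value simply wraps os.environ.get (the real helper may do more):

    import os

    def get_env_value(name: str, default=None):
        # Hypothetical stand-in for the server's helper: return the
        # environment variable when set, otherwise the caller's default.
        return os.environ.get(name, default)

    # None, unlike "", unambiguously means "no key was configured".
    default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None)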
@@ -649,22 +649,26 @@ def create_app(args):
                 texts,
                 embed_model=args.embedding_model,
                 host=args.embedding_binding_host,
+                api_key=args.embedding_binding_api_key,
             )
             if args.embedding_binding == "lollms"
             else ollama_embed(
                 texts,
                 embed_model=args.embedding_model,
                 host=args.embedding_binding_host,
+                api_key=args.embedding_binding_api_key,
             )
             if args.embedding_binding == "ollama"
             else azure_openai_embedding(
                 texts,
-                model=args.embedding_model,  # no host is used for openai
+                model=args.embedding_model,  # no host is used for openai,
+                api_key=args.embedding_binding_api_key,
             )
             if args.embedding_binding == "azure_openai"
             else openai_embedding(
                 texts,
-                model=args.embedding_model,  # no host is used for openai
+                model=args.embedding_model,  # no host is used for openai,
+                api_key=args.embedding_binding_api_key,
             ),
         )
@@ -682,6 +686,7 @@ def create_app(args):
             "host": args.llm_binding_host,
             "timeout": args.timeout,
             "options": {"num_ctx": args.max_tokens},
+            "api_key": args.llm_binding_api_key,
         },
         embedding_func=embedding_func,
     )
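
For orientation, a hypothetical condensed view of the dispatch this file now builds; the wrapper below is illustrative, while the binding functions and argument names come from the diff (assuming they are importable from lightrag.llm, as in this repo):

    from lightrag.llm import (azure_openai_embedding, lollms_embed,
                              ollama_embed, openai_embedding)

    async def embed_texts(texts: list[str], args):
        # Every binding now receives the same optional credential; a binding
        # that needs no key simply gets api_key=None.
        if args.embedding_binding == "lollms":
            return await lollms_embed(texts, embed_model=args.embedding_model,
                                      host=args.embedding_binding_host,
                                      api_key=args.embedding_binding_api_key)
        if args.embedding_binding == "ollama":
            return await ollama_embed(texts, embed_model=args.embedding_model,
                                      host=args.embedding_binding_host,
                                      api_key=args.embedding_binding_api_key)
        if args.embedding_binding == "azure_openai":
            return await azure_openai_embedding(texts, model=args.embedding_model,
                                                api_key=args.embedding_binding_api_key)
        return await openai_embedding(texts, model=args.embedding_model,
                                      api_key=args.embedding_binding_api_key)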


@@ -349,7 +349,9 @@ async def ollama_model_if_cache(
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
     kwargs.pop("hashing_kv", None)
-    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
+    api_key = kwargs.pop("api_key", None)
+    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
+    ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
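
The header-injection pattern above, sketched in isolation with a hypothetical host and key; ollama's client forwards extra keyword arguments such as headers to its underlying httpx client, and headers=None leaves unprotected local servers working unchanged:

    import ollama

    api_key = "my-secret-key"  # hypothetical credential
    # Attach a Bearer token to every request when a key is configured;
    # otherwise the client stays unauthenticated.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
    client = ollama.AsyncClient(host="http://localhost:11434", timeout=60,
                                headers=headers)
    # client is now used exactly like an unauthenticated one.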
@@ -380,6 +382,8 @@ async def lollms_model_if_cache(
     """Client implementation for lollms generation."""
     stream = True if kwargs.get("stream") else False
+    api_key = kwargs.pop("api_key", None)
+    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None

     # Extract lollms specific parameters
     request_data = {
@@ -408,7 +412,7 @@ async def lollms_model_if_cache(
     request_data["prompt"] = full_prompt
     timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))

-    async with aiohttp.ClientSession(timeout=timeout) as session:
+    async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
         if stream:

             async def inner():
@@ -622,7 +626,7 @@ async def nvidia_openai_complete(
 async def azure_openai_complete(
-    model: str = "gpt-4o-mini", prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    model: str = "gpt-4o-mini", prompt="", system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     result = await azure_openai_complete_if_cache(
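
This hunk is the bug named in the commit message: in Python a parameter without a default may not follow one that has a default, so the old signature was a SyntaxError that broke the module at import time. A quick reproduction:

    # The broken signature fails to compile at all:
    try:
        compile("def f(model='gpt-4o-mini', prompt): pass", "<demo>", "exec")
    except SyntaxError as e:
        print(e)  # "non-default argument follows default argument" (wording varies by version)

    # Giving prompt a default, as the fix does, restores a valid signature:
    compile("def f(model='gpt-4o-mini', prompt=''): pass", "<demo>", "exec")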
@@ -1148,6 +1152,9 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
+    api_key = kwargs.pop("api_key", None)
+    headers = {"Authorization": api_key, "Content-Type": "application/json"} if api_key else None
+    kwargs["headers"] = headers
     ollama_client = ollama.Client(**kwargs)
     data = ollama_client.embed(model=embed_model, input=texts)
     return data["embeddings"]
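
Note that this path sends the key verbatim in the Authorization header, without the Bearer prefix used on the chat path above, so whatever proxy protects the Ollama endpoint must accept that exact scheme. A usage sketch with hypothetical host, key, and model:

    import ollama

    # The raw key, not "Bearer <key>", is what the server will see here.
    client = ollama.Client(host="https://ollama.example.com",
                           headers={"Authorization": "my-secret-key",
                                    "Content-Type": "application/json"})
    data = client.embed(model="nomic-embed-text", input=["hello", "world"])
    print(len(data["embeddings"]))  # one vector per input text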
@@ -1168,13 +1175,15 @@ async def lollms_embed(
     Returns:
         np.ndarray: Array of embeddings
     """
-    async with aiohttp.ClientSession() as session:
+    api_key = kwargs.pop("api_key", None)
+    headers = {"Authorization": api_key, "Content-Type": "application/json"} if api_key else None
+    async with aiohttp.ClientSession(headers=headers) as session:
         embeddings = []
         for text in texts:
             request_data = {"text": text}
             async with session.post(
-                f"{base_url}/lollms_embed", json=request_data
+                f"{base_url}/lollms_embed", json=request_data,
             ) as response:
                 result = await response.json()
                 embeddings.append(result["vector"])
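
Headers passed to an aiohttp ClientSession apply to every request made through it, which is why a single authenticated session covers all the per-text POSTs in the loop. The technique in isolation, with hypothetical base_url and key:

    import aiohttp

    async def fetch_vector(text: str, base_url: str, api_key: str | None = None) -> list[float]:
        # Session-wide headers authenticate every request made below.
        headers = ({"Authorization": api_key, "Content-Type": "application/json"}
                   if api_key else None)
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.post(f"{base_url}/lollms_embed", json={"text": text}) as response:
                return (await response.json())["vector"]

    # import asyncio; asyncio.run(fetch_vector("hello", "http://localhost:9600", "my-secret-key"))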