From 9cae05e1ff72bf830b500847e2a8f0ca73bba95c Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sun, 19 Jan 2025 23:24:37 +0100 Subject: [PATCH] Fixed a bug introduced by a modification by someone else in azure_openai_complete (please make sure you test before commiting code) Added api_key to lollms, ollama, openai for both llm and embedding bindings allowing to use api key protected services. --- lightrag/api/lightrag_server.py | 11 ++++++++--- lightrag/llm.py | 19 ++++++++++++++----- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 2e0aabd7..e4cbac57 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -299,7 +299,7 @@ def parse_args() -> argparse.Namespace: ) default_llm_api_key = get_env_value( - "LLM_BINDING_API_KEY", "" + "LLM_BINDING_API_KEY", None ) parser.add_argument( @@ -649,22 +649,26 @@ def create_app(args): texts, embed_model=args.embedding_model, host=args.embedding_binding_host, + api_key = args.embedding_binding_api_key ) if args.embedding_binding == "lollms" else ollama_embed( texts, embed_model=args.embedding_model, host=args.embedding_binding_host, + api_key = args.embedding_binding_api_key ) if args.embedding_binding == "ollama" else azure_openai_embedding( texts, - model=args.embedding_model, # no host is used for openai + model=args.embedding_model, # no host is used for openai, + api_key = args.embedding_binding_api_key ) if args.embedding_binding == "azure_openai" else openai_embedding( texts, - model=args.embedding_model, # no host is used for openai + model=args.embedding_model, # no host is used for openai, + api_key = args.embedding_binding_api_key ), ) @@ -682,6 +686,7 @@ def create_app(args): "host": args.llm_binding_host, "timeout": args.timeout, "options": {"num_ctx": args.max_tokens}, + "api_key": args.llm_binding_api_key }, embedding_func=embedding_func, ) diff --git a/lightrag/llm.py b/lightrag/llm.py index 1f52d4ae..02fe3961 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -349,7 +349,9 @@ async def ollama_model_if_cache( host = kwargs.pop("host", None) timeout = kwargs.pop("timeout", None) kwargs.pop("hashing_kv", None) - ollama_client = ollama.AsyncClient(host=host, timeout=timeout) + api_key = kwargs.pop("api_key", None) + headers={'Authorization': f'Bearer {api_key}'} if api_key else None + ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -380,6 +382,8 @@ async def lollms_model_if_cache( """Client implementation for lollms generation.""" stream = True if kwargs.get("stream") else False + api_key = kwargs.pop("api_key", None) + headers={'Authorization': f'Bearer {api_key}'} if api_key else None # Extract lollms specific parameters request_data = { @@ -408,7 +412,7 @@ async def lollms_model_if_cache( request_data["prompt"] = full_prompt timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None)) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession(timeout=timeout,headers=headers) as session: if stream: async def inner(): @@ -622,7 +626,7 @@ async def nvidia_openai_complete( async def azure_openai_complete( - model: str = "gpt-4o-mini", prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs + model: str = "gpt-4o-mini", prompt="", system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs ) -> str: keyword_extraction = kwargs.pop("keyword_extraction", None) result = await azure_openai_complete_if_cache( @@ -1148,6 +1152,9 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: + api_key = kwargs.pop("api_key",None) + headers = {"Authorization": api_key, "Content-Type": "application/json"} if api_key else None + kwargs["headers"]=headers ollama_client = ollama.Client(**kwargs) data = ollama_client.embed(model=embed_model, input=texts) return data["embeddings"] @@ -1168,13 +1175,15 @@ async def lollms_embed( Returns: np.ndarray: Array of embeddings """ - async with aiohttp.ClientSession() as session: + api_key = kwargs.pop("api_key",None) + headers = {"Authorization": api_key, "Content-Type": "application/json"} if api_key else None + async with aiohttp.ClientSession(headers=headers) as session: embeddings = [] for text in texts: request_data = {"text": text} async with session.post( - f"{base_url}/lollms_embed", json=request_data + f"{base_url}/lollms_embed", json=request_data, ) as response: result = await response.json() embeddings.append(result["vector"])