fixed linting

This commit is contained in:
Saifeddine ALOUI
2025-01-20 00:26:28 +01:00
parent 2dfbbec407
commit 70425b0357
3 changed files with 39 additions and 33 deletions

View File

@@ -23,12 +23,8 @@ Requirements:
- Lightrag - Lightrag
""" """
from pathlib import Path
from typing import Optional, List, Dict, Union, Any
from datetime import datetime
# Tool version # Tool version
__version__ = "1.0.0" __version__ = "1.0.0"
__author__ = "ParisNeo" __author__ = "ParisNeo"
__author_email__ = "parisneoai@gmail.com" __author_email__ = "parisneoai@gmail.com"
__description__ = "Lightrag integration for OpenWebui" __description__ = "Lightrag integration for OpenWebui"

View File

@@ -297,15 +297,13 @@ def parse_args() -> argparse.Namespace:
default=default_llm_host, default=default_llm_host,
help=f"llm server host URL (default: from env or {default_llm_host})", help=f"llm server host URL (default: from env or {default_llm_host})",
) )
default_llm_api_key = get_env_value( default_llm_api_key = get_env_value("LLM_BINDING_API_KEY", None)
"LLM_BINDING_API_KEY", None
)
parser.add_argument( parser.add_argument(
"--llm-binding-api-key", "--llm-binding-api-key",
default=default_llm_api_key, default=default_llm_api_key,
help=f"llm server API key (default: from env or empty string)", help="llm server API key (default: from env or empty string)",
) )
parser.add_argument( parser.add_argument(
@@ -323,14 +321,12 @@ def parse_args() -> argparse.Namespace:
default=default_embedding_host, default=default_embedding_host,
help=f"embedding server host URL (default: from env or {default_embedding_host})", help=f"embedding server host URL (default: from env or {default_embedding_host})",
) )
default_embedding_api_key = get_env_value( default_embedding_api_key = get_env_value("EMBEDDING_BINDING_API_KEY", "")
"EMBEDDING_BINDING_API_KEY", ""
)
parser.add_argument( parser.add_argument(
"--embedding-binding-api-key", "--embedding-binding-api-key",
default=default_embedding_api_key, default=default_embedding_api_key,
help=f"embedding server API key (default: from env or empty string)", help="embedding server API key (default: from env or empty string)",
) )
parser.add_argument( parser.add_argument(
@@ -649,26 +645,26 @@ def create_app(args):
texts, texts,
embed_model=args.embedding_model, embed_model=args.embedding_model,
host=args.embedding_binding_host, host=args.embedding_binding_host,
api_key = args.embedding_binding_api_key api_key=args.embedding_binding_api_key,
) )
if args.embedding_binding == "lollms" if args.embedding_binding == "lollms"
else ollama_embed( else ollama_embed(
texts, texts,
embed_model=args.embedding_model, embed_model=args.embedding_model,
host=args.embedding_binding_host, host=args.embedding_binding_host,
api_key = args.embedding_binding_api_key api_key=args.embedding_binding_api_key,
) )
if args.embedding_binding == "ollama" if args.embedding_binding == "ollama"
else azure_openai_embedding( else azure_openai_embedding(
texts, texts,
model=args.embedding_model, # no host is used for openai, model=args.embedding_model, # no host is used for openai,
api_key = args.embedding_binding_api_key api_key=args.embedding_binding_api_key,
) )
if args.embedding_binding == "azure_openai" if args.embedding_binding == "azure_openai"
else openai_embedding( else openai_embedding(
texts, texts,
model=args.embedding_model, # no host is used for openai, model=args.embedding_model, # no host is used for openai,
api_key = args.embedding_binding_api_key api_key=args.embedding_binding_api_key,
), ),
) )
@@ -686,7 +682,7 @@ def create_app(args):
"host": args.llm_binding_host, "host": args.llm_binding_host,
"timeout": args.timeout, "timeout": args.timeout,
"options": {"num_ctx": args.max_tokens}, "options": {"num_ctx": args.max_tokens},
"api_key": args.llm_binding_api_key "api_key": args.llm_binding_api_key,
}, },
embedding_func=embedding_func, embedding_func=embedding_func,
) )

View File

@@ -349,8 +349,8 @@ async def ollama_model_if_cache(
host = kwargs.pop("host", None) host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None) timeout = kwargs.pop("timeout", None)
kwargs.pop("hashing_kv", None) kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None) api_key = kwargs.pop("api_key", None)
headers={'Authorization': f'Bearer {api_key}'} if api_key else None headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
messages = [] messages = []
if system_prompt: if system_prompt:
@@ -382,8 +382,8 @@ async def lollms_model_if_cache(
"""Client implementation for lollms generation.""" """Client implementation for lollms generation."""
stream = True if kwargs.get("stream") else False stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None) api_key = kwargs.pop("api_key", None)
headers={'Authorization': f'Bearer {api_key}'} if api_key else None headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
# Extract lollms specific parameters # Extract lollms specific parameters
request_data = { request_data = {
@@ -412,7 +412,7 @@ async def lollms_model_if_cache(
request_data["prompt"] = full_prompt request_data["prompt"] = full_prompt
timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None)) timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
async with aiohttp.ClientSession(timeout=timeout,headers=headers) as session: async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
if stream: if stream:
async def inner(): async def inner():
@@ -626,7 +626,12 @@ async def nvidia_openai_complete(
async def azure_openai_complete( async def azure_openai_complete(
model: str = "gpt-4o-mini", prompt="", system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs model: str = "gpt-4o-mini",
prompt="",
system_prompt=None,
history_messages=[],
keyword_extraction=False,
**kwargs,
) -> str: ) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None) keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache( result = await azure_openai_complete_if_cache(
@@ -1152,9 +1157,13 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
api_key = kwargs.pop("api_key",None) api_key = kwargs.pop("api_key", None)
headers = {"Authorization": api_key, "Content-Type": "application/json"} if api_key else None headers = (
kwargs["headers"]=headers {"Authorization": api_key, "Content-Type": "application/json"}
if api_key
else None
)
kwargs["headers"] = headers
ollama_client = ollama.Client(**kwargs) ollama_client = ollama.Client(**kwargs)
data = ollama_client.embed(model=embed_model, input=texts) data = ollama_client.embed(model=embed_model, input=texts)
return data["embeddings"] return data["embeddings"]
@@ -1175,15 +1184,20 @@ async def lollms_embed(
Returns: Returns:
np.ndarray: Array of embeddings np.ndarray: Array of embeddings
""" """
api_key = kwargs.pop("api_key",None) api_key = kwargs.pop("api_key", None)
headers = {"Authorization": api_key, "Content-Type": "application/json"} if api_key else None headers = (
{"Authorization": api_key, "Content-Type": "application/json"}
if api_key
else None
)
async with aiohttp.ClientSession(headers=headers) as session: async with aiohttp.ClientSession(headers=headers) as session:
embeddings = [] embeddings = []
for text in texts: for text in texts:
request_data = {"text": text} request_data = {"text": text}
async with session.post( async with session.post(
f"{base_url}/lollms_embed", json=request_data, f"{base_url}/lollms_embed",
json=request_data,
) as response: ) as response:
result = await response.json() result = await response.json()
embeddings.append(result["vector"]) embeddings.append(result["vector"])