run precommit to fix linting issues

This commit is contained in:
Saifeddine ALOUI
2025-01-11 01:37:07 +01:00
parent e0e656ab01
commit 224fce9b1b
2 changed files with 51 additions and 34 deletions

View File

@@ -23,21 +23,25 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": "http://localhost:11434",
"lollms": "http://localhost:9600",
"azure_openai": "https://api.openai.com/v1",
"openai": "https://api.openai.com/v1"
"openai": "https://api.openai.com/v1",
}
return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown
return default_hosts.get(
binding_type, "http://localhost:11434"
) # fallback to ollama if unknown
def parse_args():
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
#Start by the bindings
# Start by the bindings
parser.add_argument(
"--llm-binding",
default="ollama",
@@ -152,19 +156,17 @@ def parse_args():
# Optional https parameters
parser.add_argument(
"--ssl",
action="store_true",
help="Enable HTTPS (default: False)"
"--ssl", action="store_true", help="Enable HTTPS (default: False)"
)
parser.add_argument(
"--ssl-certfile",
default=None,
help="Path to SSL certificate file (required if --ssl is enabled)"
help="Path to SSL certificate file (required if --ssl is enabled)",
)
parser.add_argument(
"--ssl-keyfile",
default=None,
help="Path to SSL private key file (required if --ssl is enabled)"
help="Path to SSL private key file (required if --ssl is enabled)",
)
return parser.parse_args()
@@ -261,17 +263,17 @@ def create_app(args):
if args.embedding_binding not in ["lollms", "ollama", "openai"]:
raise Exception("embedding binding not supported")
# Add SSL validation
if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception("SSL certificate and key files must be provided when SSL is enabled")
raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Setup logging
logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
@@ -309,33 +311,48 @@ def create_app(args):
# Initialize document manager
doc_manager = DocumentManager(args.input_dir)
# Initialize RAG
rag = LightRAG(
working_dir=args.working_dir,
llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache,
llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_model_complete
if args.llm_binding == "ollama"
else azure_openai_complete_if_cache
if args.llm_binding == "azure_openai"
else openai_complete_if_cache,
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={
"host": args.llm_binding_host,
"timeout":args.timeout,
"timeout": args.timeout,
"options": {"num_ctx": args.max_tokens},
},
embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed(
texts, embed_model=args.embedding_model, host=args.embedding_binding_host
) if args.llm_binding=="lollms" else ollama_embed(
texts, embed_model=args.embedding_model, host=args.embedding_binding_host
) if args.llm_binding=="ollama" else azure_openai_embedding(
texts, model=args.embedding_model # no host is used for openai
) if args.llm_binding=="azure_openai" else openai_embedding(
texts, model=args.embedding_model # no host is used for openai
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.llm_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.llm_binding == "ollama"
else azure_openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
)
if args.llm_binding == "azure_openai"
else openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
),
),
)
@@ -568,12 +585,10 @@ def create_app(args):
"llm_binding": args.llm_binding,
"llm_binding_host": args.llm_binding_host,
"llm_model": args.llm_model,
# embedding model configuration binding/host address (if applicable)/model (if applicable)
"embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model,
"max_tokens": args.max_tokens,
},
}
@@ -592,10 +607,12 @@ def main():
"port": args.port,
}
if args.ssl:
uvicorn_config.update({
uvicorn_config.update(
{
"ssl_certfile": args.ssl_certfile,
"ssl_keyfile": args.ssl_keyfile,
})
}
)
uvicorn.run(**uvicorn_config)

View File

@@ -336,7 +336,6 @@ async def hf_model_if_cache(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def ollama_model_if_cache(
model,
prompt,
@@ -411,6 +410,7 @@ async def lollms_model_if_cache(
async with aiohttp.ClientSession(timeout=timeout) as session:
if stream:
async def inner():
async with session.post(
f"{base_url}/lollms_generate", json=request_data