run precommit to fix linting issues

This commit is contained in:
Saifeddine ALOUI
2025-01-11 01:37:07 +01:00
parent e0e656ab01
commit 224fce9b1b
2 changed files with 51 additions and 34 deletions

View File

@@ -23,21 +23,25 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.status import HTTP_403_FORBIDDEN from starlette.status import HTTP_403_FORBIDDEN
def get_default_host(binding_type: str) -> str: def get_default_host(binding_type: str) -> str:
default_hosts = { default_hosts = {
"ollama": "http://localhost:11434", "ollama": "http://localhost:11434",
"lollms": "http://localhost:9600", "lollms": "http://localhost:9600",
"azure_openai": "https://api.openai.com/v1", "azure_openai": "https://api.openai.com/v1",
"openai": "https://api.openai.com/v1" "openai": "https://api.openai.com/v1",
} }
return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown return default_hosts.get(
binding_type, "http://localhost:11434"
) # fallback to ollama if unknown
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories" description="LightRAG FastAPI Server with separate working and input directories"
) )
#Start by the bindings # Start by the bindings
parser.add_argument( parser.add_argument(
"--llm-binding", "--llm-binding",
default="ollama", default="ollama",
@@ -48,7 +52,7 @@ def parse_args():
default="ollama", default="ollama",
help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)",
) )
# Parse just these arguments first # Parse just these arguments first
temp_args, _ = parser.parse_known_args() temp_args, _ = parser.parse_known_args()
@@ -152,19 +156,17 @@ def parse_args():
# Optional https parameters # Optional https parameters
parser.add_argument( parser.add_argument(
"--ssl", "--ssl", action="store_true", help="Enable HTTPS (default: False)"
action="store_true",
help="Enable HTTPS (default: False)"
) )
parser.add_argument( parser.add_argument(
"--ssl-certfile", "--ssl-certfile",
default=None, default=None,
help="Path to SSL certificate file (required if --ssl is enabled)" help="Path to SSL certificate file (required if --ssl is enabled)",
) )
parser.add_argument( parser.add_argument(
"--ssl-keyfile", "--ssl-keyfile",
default=None, default=None,
help="Path to SSL private key file (required if --ssl is enabled)" help="Path to SSL private key file (required if --ssl is enabled)",
) )
return parser.parse_args() return parser.parse_args()
@@ -261,17 +263,17 @@ def create_app(args):
if args.embedding_binding not in ["lollms", "ollama", "openai"]: if args.embedding_binding not in ["lollms", "ollama", "openai"]:
raise Exception("embedding binding not supported") raise Exception("embedding binding not supported")
# Add SSL validation # Add SSL validation
if args.ssl: if args.ssl:
if not args.ssl_certfile or not args.ssl_keyfile: if not args.ssl_certfile or not args.ssl_keyfile:
raise Exception("SSL certificate and key files must be provided when SSL is enabled") raise Exception(
"SSL certificate and key files must be provided when SSL is enabled"
)
if not os.path.exists(args.ssl_certfile): if not os.path.exists(args.ssl_certfile):
raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
if not os.path.exists(args.ssl_keyfile): if not os.path.exists(args.ssl_keyfile):
raise Exception(f"SSL key file not found: {args.ssl_keyfile}") raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
@@ -309,33 +311,48 @@ def create_app(args):
# Initialize document manager # Initialize document manager
doc_manager = DocumentManager(args.input_dir) doc_manager = DocumentManager(args.input_dir)
# Initialize RAG # Initialize RAG
rag = LightRAG( rag = LightRAG(
working_dir=args.working_dir, working_dir=args.working_dir,
llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache, llm_model_func=lollms_model_complete
if args.llm_binding == "lollms"
else ollama_model_complete
if args.llm_binding == "ollama"
else azure_openai_complete_if_cache
if args.llm_binding == "azure_openai"
else openai_complete_if_cache,
llm_model_name=args.llm_model, llm_model_name=args.llm_model,
llm_model_max_async=args.max_async, llm_model_max_async=args.max_async,
llm_model_max_token_size=args.max_tokens, llm_model_max_token_size=args.max_tokens,
llm_model_kwargs={ llm_model_kwargs={
"host": args.llm_binding_host, "host": args.llm_binding_host,
"timeout":args.timeout, "timeout": args.timeout,
"options": {"num_ctx": args.max_tokens}, "options": {"num_ctx": args.max_tokens},
}, },
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=args.embedding_dim, embedding_dim=args.embedding_dim,
max_token_size=args.max_embed_tokens, max_token_size=args.max_embed_tokens,
func=lambda texts: lollms_embed( func=lambda texts: lollms_embed(
texts, embed_model=args.embedding_model, host=args.embedding_binding_host texts,
) if args.llm_binding=="lollms" else ollama_embed( embed_model=args.embedding_model,
texts, embed_model=args.embedding_model, host=args.embedding_binding_host host=args.embedding_binding_host,
) if args.llm_binding=="ollama" else azure_openai_embedding(
texts, model=args.embedding_model # no host is used for openai
) if args.llm_binding=="azure_openai" else openai_embedding(
texts, model=args.embedding_model # no host is used for openai
) )
if args.llm_binding == "lollms"
else ollama_embed(
texts,
embed_model=args.embedding_model,
host=args.embedding_binding_host,
)
if args.llm_binding == "ollama"
else azure_openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
)
if args.llm_binding == "azure_openai"
else openai_embedding(
texts,
model=args.embedding_model, # no host is used for openai
),
), ),
) )
@@ -568,12 +585,10 @@ def create_app(args):
"llm_binding": args.llm_binding, "llm_binding": args.llm_binding,
"llm_binding_host": args.llm_binding_host, "llm_binding_host": args.llm_binding_host,
"llm_model": args.llm_model, "llm_model": args.llm_model,
# embedding model configuration binding/host address (if applicable)/model (if applicable) # embedding model configuration binding/host address (if applicable)/model (if applicable)
"embedding_binding": args.embedding_binding, "embedding_binding": args.embedding_binding,
"embedding_binding_host": args.embedding_binding_host, "embedding_binding_host": args.embedding_binding_host,
"embedding_model": args.embedding_model, "embedding_model": args.embedding_model,
"max_tokens": args.max_tokens, "max_tokens": args.max_tokens,
}, },
} }
@@ -590,12 +605,14 @@ def main():
"app": app, "app": app,
"host": args.host, "host": args.host,
"port": args.port, "port": args.port,
} }
if args.ssl: if args.ssl:
uvicorn_config.update({ uvicorn_config.update(
"ssl_certfile": args.ssl_certfile, {
"ssl_keyfile": args.ssl_keyfile, "ssl_certfile": args.ssl_certfile,
}) "ssl_keyfile": args.ssl_keyfile,
}
)
uvicorn.run(**uvicorn_config) uvicorn.run(**uvicorn_config)

View File

@@ -336,7 +336,6 @@ async def hf_model_if_cache(
(RateLimitError, APIConnectionError, APITimeoutError) (RateLimitError, APIConnectionError, APITimeoutError)
), ),
) )
async def ollama_model_if_cache( async def ollama_model_if_cache(
model, model,
prompt, prompt,
@@ -411,6 +410,7 @@ async def lollms_model_if_cache(
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
if stream: if stream:
async def inner(): async def inner():
async with session.post( async with session.post(
f"{base_url}/lollms_generate", json=request_data f"{base_url}/lollms_generate", json=request_data