From 224fce9b1b1a887a998d2ee818f0855c950422de Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 11 Jan 2025 01:37:07 +0100 Subject: [PATCH] run precommit to fix linting issues --- lightrag/api/lightrag_server.py | 83 ++++++++++++++++++++------------- lightrag/llm.py | 2 +- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1f88e776..644e622d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -23,21 +23,25 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN + def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": "http://localhost:11434", "lollms": "http://localhost:9600", "azure_openai": "https://api.openai.com/v1", - "openai": "https://api.openai.com/v1" + "openai": "https://api.openai.com/v1", } - return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown + return default_hosts.get( + binding_type, "http://localhost:11434" + ) # fallback to ollama if unknown + def parse_args(): parser = argparse.ArgumentParser( description="LightRAG FastAPI Server with separate working and input directories" ) - #Start by the bindings + # Start by the bindings parser.add_argument( "--llm-binding", default="ollama", @@ -48,7 +52,7 @@ def parse_args(): default="ollama", help="Embedding binding to be used. Supported: lollms, ollama, openai (default: ollama)", ) - + # Parse just these arguments first temp_args, _ = parser.parse_known_args() @@ -152,19 +156,17 @@ def parse_args(): # Optional https parameters parser.add_argument( - "--ssl", - action="store_true", - help="Enable HTTPS (default: False)" + "--ssl", action="store_true", help="Enable HTTPS (default: False)" ) parser.add_argument( "--ssl-certfile", default=None, - help="Path to SSL certificate file (required if --ssl is enabled)" + help="Path to SSL certificate file (required if --ssl is enabled)", ) parser.add_argument( "--ssl-keyfile", - default=None, - help="Path to SSL private key file (required if --ssl is enabled)" + default=None, + help="Path to SSL private key file (required if --ssl is enabled)", ) return parser.parse_args() @@ -261,17 +263,17 @@ def create_app(args): if args.embedding_binding not in ["lollms", "ollama", "openai"]: raise Exception("embedding binding not supported") - # Add SSL validation if args.ssl: if not args.ssl_certfile or not args.ssl_keyfile: - raise Exception("SSL certificate and key files must be provided when SSL is enabled") + raise Exception( + "SSL certificate and key files must be provided when SSL is enabled" + ) if not os.path.exists(args.ssl_certfile): raise Exception(f"SSL certificate file not found: {args.ssl_certfile}") if not os.path.exists(args.ssl_keyfile): raise Exception(f"SSL key file not found: {args.ssl_keyfile}") - - + # Setup logging logging.basicConfig( format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) @@ -309,33 +311,48 @@ def create_app(args): # Initialize document manager doc_manager = DocumentManager(args.input_dir) - - # Initialize RAG rag = LightRAG( working_dir=args.working_dir, - llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache, + llm_model_func=lollms_model_complete + if args.llm_binding == "lollms" + else ollama_model_complete + if args.llm_binding == "ollama" + else azure_openai_complete_if_cache + if args.llm_binding == "azure_openai" + else openai_complete_if_cache, llm_model_name=args.llm_model, llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, llm_model_kwargs={ "host": args.llm_binding_host, - "timeout":args.timeout, + "timeout": args.timeout, "options": {"num_ctx": args.max_tokens}, }, embedding_func=EmbeddingFunc( embedding_dim=args.embedding_dim, max_token_size=args.max_embed_tokens, func=lambda texts: lollms_embed( - texts, embed_model=args.embedding_model, host=args.embedding_binding_host - ) if args.llm_binding=="lollms" else ollama_embed( - texts, embed_model=args.embedding_model, host=args.embedding_binding_host - ) if args.llm_binding=="ollama" else azure_openai_embedding( - texts, model=args.embedding_model # no host is used for openai - ) if args.llm_binding=="azure_openai" else openai_embedding( - texts, model=args.embedding_model # no host is used for openai + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, ) - + if args.llm_binding == "lollms" + else ollama_embed( + texts, + embed_model=args.embedding_model, + host=args.embedding_binding_host, + ) + if args.llm_binding == "ollama" + else azure_openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ) + if args.llm_binding == "azure_openai" + else openai_embedding( + texts, + model=args.embedding_model, # no host is used for openai + ), ), ) @@ -568,12 +585,10 @@ def create_app(args): "llm_binding": args.llm_binding, "llm_binding_host": args.llm_binding_host, "llm_model": args.llm_model, - # embedding model configuration binding/host address (if applicable)/model (if applicable) "embedding_binding": args.embedding_binding, "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, - "max_tokens": args.max_tokens, }, } @@ -590,12 +605,14 @@ def main(): "app": app, "host": args.host, "port": args.port, - } + } if args.ssl: - uvicorn_config.update({ - "ssl_certfile": args.ssl_certfile, - "ssl_keyfile": args.ssl_keyfile, - }) + uvicorn_config.update( + { + "ssl_certfile": args.ssl_certfile, + "ssl_keyfile": args.ssl_keyfile, + } + ) uvicorn.run(**uvicorn_config) diff --git a/lightrag/llm.py b/lightrag/llm.py index 7a51d025..c49ed138 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -336,7 +336,6 @@ async def hf_model_if_cache( (RateLimitError, APIConnectionError, APITimeoutError) ), ) - async def ollama_model_if_cache( model, prompt, @@ -411,6 +410,7 @@ async def lollms_model_if_cache( async with aiohttp.ClientSession(timeout=timeout) as session: if stream: + async def inner(): async with session.post( f"{base_url}/lollms_generate", json=request_data