run pre-commit to fix linting issues
@@ -23,21 +23,25 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.status import HTTP_403_FORBIDDEN


 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": "http://localhost:11434",
         "lollms": "http://localhost:9600",
         "azure_openai": "https://api.openai.com/v1",
-        "openai": "https://api.openai.com/v1"
+        "openai": "https://api.openai.com/v1",
     }
-    return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown
+    return default_hosts.get(
+        binding_type, "http://localhost:11434"
+    )  # fallback to ollama if unknown


 def parse_args():
     parser = argparse.ArgumentParser(
         description="LightRAG FastAPI Server with separate working and input directories"
     )

-    #Start by the bindings
+    # Start by the bindings
     parser.add_argument(
         "--llm-binding",
         default="ollama",
@@ -152,19 +156,17 @@ def parse_args():

     # Optional https parameters
     parser.add_argument(
-        "--ssl",
-        action="store_true",
-        help="Enable HTTPS (default: False)"
+        "--ssl", action="store_true", help="Enable HTTPS (default: False)"
     )
     parser.add_argument(
         "--ssl-certfile",
         default=None,
-        help="Path to SSL certificate file (required if --ssl is enabled)"
+        help="Path to SSL certificate file (required if --ssl is enabled)",
     )
     parser.add_argument(
         "--ssl-keyfile",
         default=None,
-        help="Path to SSL private key file (required if --ssl is enabled)"
+        help="Path to SSL private key file (required if --ssl is enabled)",
     )
     return parser.parse_args()
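These are pure formatting changes (the `--ssl` flag collapsed onto one line, trailing commas added); the CLI surface is identical. For illustration, with hypothetical file paths:

    # e.g. invoked as: <server script> --ssl --ssl-certfile cert.pem --ssl-keyfile key.pem
    args = parse_args()
    # args.ssl -> True; args.ssl_certfile -> "cert.pem"; args.ssl_keyfile -> "key.pem"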
@@ -261,17 +263,17 @@ def create_app(args):
     if args.embedding_binding not in ["lollms", "ollama", "openai"]:
         raise Exception("embedding binding not supported")

-
     # Add SSL validation
     if args.ssl:
         if not args.ssl_certfile or not args.ssl_keyfile:
-            raise Exception("SSL certificate and key files must be provided when SSL is enabled")
+            raise Exception(
+                "SSL certificate and key files must be provided when SSL is enabled"
+            )
         if not os.path.exists(args.ssl_certfile):
             raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
         if not os.path.exists(args.ssl_keyfile):
             raise Exception(f"SSL key file not found: {args.ssl_keyfile}")

-
     # Setup logging
     logging.basicConfig(
         format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
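The long `raise` is only wrapped for line length; the validation logic is untouched: both files must be supplied and must exist before the app is created. As a standalone sketch of the same pattern (hypothetical helper name, not in the diff):

    import os

    def validate_ssl_args(ssl: bool, certfile, keyfile) -> None:
        if not ssl:
            return
        if not certfile or not keyfile:
            raise Exception("SSL certificate and key files must be provided when SSL is enabled")
        for path in (certfile, keyfile):
            if not os.path.exists(path):
                raise Exception(f"SSL file not found: {path}")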
@@ -309,33 +311,48 @@ def create_app(args):
     # Initialize document manager
     doc_manager = DocumentManager(args.input_dir)


     # Initialize RAG
     rag = LightRAG(
         working_dir=args.working_dir,
-        llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache,
+        llm_model_func=lollms_model_complete
+        if args.llm_binding == "lollms"
+        else ollama_model_complete
+        if args.llm_binding == "ollama"
+        else azure_openai_complete_if_cache
+        if args.llm_binding == "azure_openai"
+        else openai_complete_if_cache,
         llm_model_name=args.llm_model,
         llm_model_max_async=args.max_async,
         llm_model_max_token_size=args.max_tokens,
         llm_model_kwargs={
             "host": args.llm_binding_host,
-            "timeout":args.timeout,
+            "timeout": args.timeout,
             "options": {"num_ctx": args.max_tokens},
         },
         embedding_func=EmbeddingFunc(
             embedding_dim=args.embedding_dim,
             max_token_size=args.max_embed_tokens,
             func=lambda texts: lollms_embed(
-                texts, embed_model=args.embedding_model, host=args.embedding_binding_host
-            ) if args.llm_binding=="lollms" else ollama_embed(
-                texts, embed_model=args.embedding_model, host=args.embedding_binding_host
-            ) if args.llm_binding=="ollama" else azure_openai_embedding(
-                texts, model=args.embedding_model # no host is used for openai
-            ) if args.llm_binding=="azure_openai" else openai_embedding(
-                texts, model=args.embedding_model # no host is used for openai
+                texts,
+                embed_model=args.embedding_model,
+                host=args.embedding_binding_host,
+            )
+            if args.llm_binding == "lollms"
+            else ollama_embed(
+                texts,
+                embed_model=args.embedding_model,
+                host=args.embedding_binding_host,
+            )
+            if args.llm_binding == "ollama"
+            else azure_openai_embedding(
+                texts,
+                model=args.embedding_model,  # no host is used for openai
+            )
+            if args.llm_binding == "azure_openai"
+            else openai_embedding(
+                texts,
+                model=args.embedding_model,  # no host is used for openai
             ),
         ),
     )
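The chained conditional for `llm_model_func` (and the matching chain inside `embedding_func`) is exactly what the formatter produces from the original one-liners. For comparison only, an equivalent dict-dispatch form, using the same names already imported in this file; this is not what the commit does:

    # illustrative alternative, behavior-equivalent to the conditional chain
    llm_funcs = {
        "lollms": lollms_model_complete,
        "ollama": ollama_model_complete,
        "azure_openai": azure_openai_complete_if_cache,
    }
    llm_model_func = llm_funcs.get(args.llm_binding, openai_complete_if_cache)

Note that the embedding chain keys off `args.llm_binding` rather than `args.embedding_binding` (which the earlier hunk validates), in both the old and new versions; that looks like a pre-existing bug the formatter simply preserved.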
@@ -568,12 +585,10 @@ def create_app(args):
             "llm_binding": args.llm_binding,
             "llm_binding_host": args.llm_binding_host,
             "llm_model": args.llm_model,
-
             # embedding model configuration binding/host address (if applicable)/model (if applicable)
             "embedding_binding": args.embedding_binding,
             "embedding_binding_host": args.embedding_binding_host,
             "embedding_model": args.embedding_model,
-
             "max_tokens": args.max_tokens,
         },
     }
@@ -592,10 +607,12 @@ def main():
         "port": args.port,
     }
     if args.ssl:
-        uvicorn_config.update({
-            "ssl_certfile": args.ssl_certfile,
-            "ssl_keyfile": args.ssl_keyfile,
-        })
+        uvicorn_config.update(
+            {
+                "ssl_certfile": args.ssl_certfile,
+                "ssl_keyfile": args.ssl_keyfile,
+            }
+        )
     uvicorn.run(**uvicorn_config)
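Again behavior-preserving: `uvicorn.run` accepts `ssl_certfile` and `ssl_keyfile` keyword arguments, so unpacking the updated dict enables HTTPS. A minimal sketch with hypothetical host, port, and app path:

    import uvicorn

    uvicorn_config = {"host": "0.0.0.0", "port": 9621}  # values assumed for illustration
    uvicorn_config.update({"ssl_certfile": "cert.pem", "ssl_keyfile": "key.pem"})
    uvicorn.run("lightrag_server:app", **uvicorn_config)  # app path assumed

The remaining two hunks touch a second file (the hunk numbering restarts at line 336).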
@@ -336,7 +336,6 @@ async def hf_model_if_cache(
         (RateLimitError, APIConnectionError, APITimeoutError)
     ),
 )
-
 async def ollama_model_if_cache(
     model,
     prompt,
@@ -411,6 +410,7 @@ async def lollms_model_if_cache(

     async with aiohttp.ClientSession(timeout=timeout) as session:
         if stream:
+
             async def inner():
                 async with session.post(
                     f"{base_url}/lollms_generate", json=request_data
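Both hunks in this second file are blank-line lint fixes: the previous one deletes the blank line between what appears to be a `@retry(...)` decorator and `async def ollama_model_if_cache` (pycodestyle E304), and this one adds the blank line expected before the nested `async def inner()` (E306). The streaming pattern itself is unchanged: the coroutine builds an inner async generator that reads the POST response incrementally. A self-contained sketch of that pattern, with hypothetical endpoint and payload:

    import asyncio
    import aiohttp

    async def stream_generate(base_url: str, request_data: dict):
        async with aiohttp.ClientSession() as session:

            async def inner():
                async with session.post(
                    f"{base_url}/lollms_generate", json=request_data
                ) as response:
                    # response.content is a StreamReader; iterate chunks as they arrive
                    async for chunk in response.content:
                        yield chunk.decode(errors="ignore")

            async for piece in inner():
                print(piece, end="")

    # asyncio.run(stream_generate("http://localhost:9600", {"prompt": "hello"}))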