run pre-commit to fix linting issues
@@ -23,14 +23,18 @@ from fastapi.middleware.cors import CORSMiddleware
 
 from starlette.status import HTTP_403_FORBIDDEN
 
 
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": "http://localhost:11434",
         "lollms": "http://localhost:9600",
         "azure_openai": "https://api.openai.com/v1",
-        "openai": "https://api.openai.com/v1"
+        "openai": "https://api.openai.com/v1",
     }
-    return default_hosts.get(binding_type, "http://localhost:11434") # fallback to ollama if unknown
+    return default_hosts.get(
+        binding_type, "http://localhost:11434"
+    )  # fallback to ollama if unknown
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
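The reflow above only adds a trailing comma and splits the long return; the fallback semantics are unchanged. A self-contained sketch copied from the hunk (the probe values are illustrative):

    def get_default_host(binding_type: str) -> str:
        default_hosts = {
            "ollama": "http://localhost:11434",
            "lollms": "http://localhost:9600",
            "azure_openai": "https://api.openai.com/v1",
            "openai": "https://api.openai.com/v1",
        }
        return default_hosts.get(binding_type, "http://localhost:11434")

    assert get_default_host("lollms") == "http://localhost:9600"
    assert get_default_host("no-such-binding") == "http://localhost:11434"  # ollama fallback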
@@ -152,19 +156,17 @@ def parse_args():
 
     # Optional https parameters
     parser.add_argument(
-        "--ssl",
-        action="store_true",
-        help="Enable HTTPS (default: False)"
+        "--ssl", action="store_true", help="Enable HTTPS (default: False)"
     )
     parser.add_argument(
         "--ssl-certfile",
         default=None,
-        help="Path to SSL certificate file (required if --ssl is enabled)"
+        help="Path to SSL certificate file (required if --ssl is enabled)",
     )
     parser.add_argument(
         "--ssl-keyfile",
         default=None,
-        help="Path to SSL private key file (required if --ssl is enabled)"
+        help="Path to SSL private key file (required if --ssl is enabled)",
     )
     return parser.parse_args()
 
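All three flags land on predictable argparse attributes (dashes become underscores). A standalone sketch mirroring just these arguments (the parser here is a stand-in, not the server's full parse_args):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ssl", action="store_true", help="Enable HTTPS (default: False)")
    parser.add_argument("--ssl-certfile", default=None)
    parser.add_argument("--ssl-keyfile", default=None)

    args = parser.parse_args(["--ssl", "--ssl-certfile", "cert.pem", "--ssl-keyfile", "key.pem"])
    assert args.ssl and args.ssl_certfile == "cert.pem" and args.ssl_keyfile == "key.pem"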
@@ -261,17 +263,17 @@ def create_app(args):
     if args.embedding_binding not in ["lollms", "ollama", "openai"]:
         raise Exception("embedding binding not supported")
 
 
     # Add SSL validation
     if args.ssl:
         if not args.ssl_certfile or not args.ssl_keyfile:
-            raise Exception("SSL certificate and key files must be provided when SSL is enabled")
+            raise Exception(
+                "SSL certificate and key files must be provided when SSL is enabled"
+            )
         if not os.path.exists(args.ssl_certfile):
             raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
         if not os.path.exists(args.ssl_keyfile):
             raise Exception(f"SSL key file not found: {args.ssl_keyfile}")
 
 
     # Setup logging
     logging.basicConfig(
         format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
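The checks run in a deliberate order: both paths must be supplied before either is tested for existence. A sketch of the failure mode, with the checks pulled into a hypothetical validate_ssl helper (not in the codebase) and SimpleNamespace standing in for the parsed args:

    import os
    from types import SimpleNamespace

    def validate_ssl(args):
        # Same checks as the hunk above: presence first, then existence.
        if args.ssl:
            if not args.ssl_certfile or not args.ssl_keyfile:
                raise Exception(
                    "SSL certificate and key files must be provided when SSL is enabled"
                )
            if not os.path.exists(args.ssl_certfile):
                raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
            if not os.path.exists(args.ssl_keyfile):
                raise Exception(f"SSL key file not found: {args.ssl_keyfile}")

    try:
        validate_ssl(SimpleNamespace(ssl=True, ssl_certfile=None, ssl_keyfile=None))
    except Exception as exc:
        print(exc)  # SSL certificate and key files must be provided when SSL is enabled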
@@ -309,12 +311,16 @@ def create_app(args):
     # Initialize document manager
     doc_manager = DocumentManager(args.input_dir)
 
 
 
     # Initialize RAG
     rag = LightRAG(
         working_dir=args.working_dir,
-        llm_model_func=lollms_model_complete if args.llm_binding=="lollms" else ollama_model_complete if args.llm_binding=="ollama" else azure_openai_complete_if_cache if args.llm_binding=="azure_openai" else openai_complete_if_cache,
+        llm_model_func=lollms_model_complete
+        if args.llm_binding == "lollms"
+        else ollama_model_complete
+        if args.llm_binding == "ollama"
+        else azure_openai_complete_if_cache
+        if args.llm_binding == "azure_openai"
+        else openai_complete_if_cache,
         llm_model_name=args.llm_model,
         llm_model_max_async=args.max_async,
         llm_model_max_token_size=args.max_tokens,
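The formatter (black, by the look of the changes) stacks a chained conditional expression one clause per line, which is what produced the llm_model_func layout above. A dict dispatch (hypothetical, not part of this commit) expresses the same selection more flatly; stub callables stand in for the real model functions:

    def lollms_model_complete(*a, **kw): ...
    def ollama_model_complete(*a, **kw): ...
    def azure_openai_complete_if_cache(*a, **kw): ...
    def openai_complete_if_cache(*a, **kw): ...

    LLM_FUNCS = {
        "lollms": lollms_model_complete,
        "ollama": ollama_model_complete,
        "azure_openai": azure_openai_complete_if_cache,
    }

    def pick_llm_func(llm_binding: str):
        # Unknown bindings fall through to openai, matching the final else above.
        return LLM_FUNCS.get(llm_binding, openai_complete_if_cache)

    assert pick_llm_func("ollama") is ollama_model_complete
    assert pick_llm_func("openai") is openai_complete_if_cache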
@@ -327,15 +333,26 @@ def create_app(args):
         embedding_dim=args.embedding_dim,
         max_token_size=args.max_embed_tokens,
         func=lambda texts: lollms_embed(
-            texts, embed_model=args.embedding_model, host=args.embedding_binding_host
-        ) if args.llm_binding=="lollms" else ollama_embed(
-            texts, embed_model=args.embedding_model, host=args.embedding_binding_host
-        ) if args.llm_binding=="ollama" else azure_openai_embedding(
-            texts, model=args.embedding_model # no host is used for openai
-        ) if args.llm_binding=="azure_openai" else openai_embedding(
-            texts, model=args.embedding_model # no host is used for openai
+            texts,
+            embed_model=args.embedding_model,
+            host=args.embedding_binding_host,
         )
+        if args.llm_binding == "lollms"
+        else ollama_embed(
+            texts,
+            embed_model=args.embedding_model,
+            host=args.embedding_binding_host,
+        )
+        if args.llm_binding == "ollama"
+        else azure_openai_embedding(
+            texts,
+            model=args.embedding_model,  # no host is used for openai
+        )
+        if args.llm_binding == "azure_openai"
+        else openai_embedding(
+            texts,
+            model=args.embedding_model,  # no host is used for openai
+        ),
         ),
     )
 
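The reformatted lambda leans on the right associativity of Python's conditional expression: a if c1 else b if c2 else c parses as a if c1 else (b if c2 else c), so the final openai_embedding call is the default branch. A compact demonstration:

    def pick(binding: str) -> str:
        return (
            "lollms"
            if binding == "lollms"
            else "ollama"
            if binding == "ollama"
            else "openai"  # default clause, reached when no earlier test matches
        )

    assert pick("ollama") == "ollama"
    assert pick("anything-else") == "openai"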
@@ -568,12 +585,10 @@ def create_app(args):
             "llm_binding": args.llm_binding,
             "llm_binding_host": args.llm_binding_host,
             "llm_model": args.llm_model,
-
             # embedding model configuration binding/host address (if applicable)/model (if applicable)
             "embedding_binding": args.embedding_binding,
             "embedding_binding_host": args.embedding_binding_host,
             "embedding_model": args.embedding_model,
-
             "max_tokens": args.max_tokens,
         },
     }
@@ -592,10 +607,12 @@ def main():
         "port": args.port,
     }
     if args.ssl:
-        uvicorn_config.update({
+        uvicorn_config.update(
+            {
                 "ssl_certfile": args.ssl_certfile,
                 "ssl_keyfile": args.ssl_keyfile,
-        })
+            }
+        )
     uvicorn.run(**uvicorn_config)
 
 
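uvicorn.run() accepts ssl_certfile and ssl_keyfile keyword arguments, so unpacking the updated dict switches the server to HTTPS. A minimal sketch of the pattern (the app, host, port, and certificate paths are placeholders):

    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()  # placeholder app

    uvicorn_config = {"app": app, "host": "0.0.0.0", "port": 9621}
    ssl_enabled = True  # stands in for args.ssl
    if ssl_enabled:
        uvicorn_config.update(
            {
                "ssl_certfile": "cert.pem",  # placeholder paths
                "ssl_keyfile": "key.pem",
            }
        )
    uvicorn.run(**uvicorn_config)  # serves over HTTPS when the ssl_* keys are present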
@@ -336,7 +336,6 @@ async def hf_model_if_cache(
         (RateLimitError, APIConnectionError, APITimeoutError)
     ),
 )
-
 async def ollama_model_if_cache(
     model,
     prompt,
@@ -411,6 +410,7 @@ async def lollms_model_if_cache(
 
     async with aiohttp.ClientSession(timeout=timeout) as session:
         if stream:
+
             async def inner():
                 async with session.post(
                     f"{base_url}/lollms_generate", json=request_data
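The only change here is the blank line that pycodestyle-style checkers expect before a nested definition (E306, "expected 1 blank line before a nested definition"). A minimal reproduction of the accepted shape:

    async def fetch(stream: bool):
        if stream:

            async def inner():  # the blank line above satisfies E306
                return "chunked"

            return await inner()
        return "whole"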