diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 268efc1d..0d154b38 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -3,10 +3,10 @@ from pydantic import BaseModel
 import logging
 import argparse
 from lightrag import LightRAG, QueryParam
-# from lightrag.llm import lollms_model_complete, lollms_embed
-# from lightrag.llm import ollama_model_complete, ollama_embed, openai_embedding
-from lightrag.llm import openai_complete_if_cache, ollama_embedding
-# from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
+from lightrag.llm import lollms_model_complete, lollms_embed
+from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 from lightrag.utils import EmbeddingFunc
 from typing import Optional, List, Union
 
@@ -24,28 +24,13 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.status import HTTP_403_FORBIDDEN
 import pipmaster as pm
 
-from dotenv import load_dotenv
-load_dotenv()
-
-async def llm_model_func(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    return await openai_complete_if_cache(
-        "deepseek-chat",
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        api_key=os.getenv("DEEPSEEK_API_KEY"),
-        base_url=os.getenv("DEEPSEEK_ENDPOINT"),
-        **kwargs,
-    )
 
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
-        "ollama": "http://m4.lan.znipower.com:11434",
+        "ollama": "http://localhost:11434",
         "lollms": "http://localhost:9600",
         "azure_openai": "https://api.openai.com/v1",
-        "openai": os.getenv("DEEPSEEK_ENDPOINT"),
+        "openai": "https://api.openai.com/v1",
     }
     return default_hosts.get(
         binding_type, "http://localhost:11434"
@@ -334,12 +319,44 @@ def create_app(args):
     # Initialize RAG
     rag = LightRAG(
         working_dir=args.working_dir,
-        llm_model_func=llm_model_func,
+        llm_model_func=lollms_model_complete
+        if args.llm_binding == "lollms"
+        else ollama_model_complete
+        if args.llm_binding == "ollama"
+        else azure_openai_complete_if_cache
+        if args.llm_binding == "azure_openai"
+        else openai_complete_if_cache,
+        llm_model_name=args.llm_model,
+        llm_model_max_async=args.max_async,
+        llm_model_max_token_size=args.max_tokens,
+        llm_model_kwargs={
+            "host": args.llm_binding_host,
+            "timeout": args.timeout,
+            "options": {"num_ctx": args.max_tokens},
+        },
         embedding_func=EmbeddingFunc(
-            embedding_dim=1024,
-            max_token_size=8192,
-            func=lambda texts: ollama_embedding(
-                texts, embed_model="bge-m3:latest", host="http://m4.lan.znipower.com:11434"
+            embedding_dim=args.embedding_dim,
+            max_token_size=args.max_embed_tokens,
+            func=lambda texts: lollms_embed(
+                texts,
+                embed_model=args.embedding_model,
+                host=args.embedding_binding_host,
+            )
+            if args.llm_binding == "lollms"
+            else ollama_embed(
+                texts,
+                embed_model=args.embedding_model,
+                host=args.embedding_binding_host,
+            )
+            if args.llm_binding == "ollama"
+            else azure_openai_embedding(
+                texts,
+                model=args.embedding_model,  # no host is used for openai
+            )
+            if args.llm_binding == "azure_openai"
+            else openai_embedding(
+                texts,
+                model=args.embedding_model,  # no host is used for openai
            ),
        ),
    )
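
For readers tracing the binding dispatch introduced above: the chained conditional expressions behave like a lookup table keyed on `args.llm_binding`, with the plain OpenAI functions as the fallback. A minimal sketch of that reading (not part of the patch; `_LLM_BINDINGS` and `select_llm_func` are hypothetical names used only for illustration):

```python
from lightrag.llm import (
    azure_openai_complete_if_cache,
    lollms_model_complete,
    ollama_model_complete,
    openai_complete_if_cache,
)

# Mirror of the chained conditionals in the patch: each binding name
# selects its completion function; anything unrecognized falls through
# to the plain OpenAI variant, just like the final `else` branch.
_LLM_BINDINGS = {
    "lollms": lollms_model_complete,
    "ollama": ollama_model_complete,
    "azure_openai": azure_openai_complete_if_cache,
}


def select_llm_func(llm_binding: str):
    """Return the completion function for the configured binding."""
    return _LLM_BINDINGS.get(llm_binding, openai_complete_if_cache)
```

The same pattern covers the embedding lambda (`lollms_embed` / `ollama_embed` / `azure_openai_embedding` / `openai_embedding`); note that the patch keys the embedding choice on `args.llm_binding` as well, not on a separate embedding binding.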