From 846937195e8b93bfa3a2ce59b542fdaee17eb5d7 Mon Sep 17 00:00:00 2001 From: 90houlaoheshang <907333918@qq.com> Date: Wed, 6 Nov 2024 11:13:37 +0800 Subject: [PATCH] =?UTF-8?q?=E9=9B=86=E4=B8=AD=E5=A4=84=E7=90=86=E7=8E=AF?= =?UTF-8?q?=E5=A2=83=E5=8F=98=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/lightrag_api_openai_compatible_demo.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index bc56ac59..20a05a5f 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -18,6 +18,13 @@ app = FastAPI(title="LightRAG API", description="API for RAG operations") # Configure working directory WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") print(f"WORKING_DIR: {WORKING_DIR}") +LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o-mini") +print(f"LLM_MODEL: {LLM_MODEL}") +EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large") +print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}") +EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192)) +print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}") + if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -29,7 +36,7 @@ async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: return await openai_complete_if_cache( - os.environ.get("LLM_MODEL", "gpt-4o-mini"), + LLM_MODEL, prompt, system_prompt=system_prompt, history_messages=history_messages, @@ -43,7 +50,7 @@ async def llm_model_func( async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, - model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"), + model=EMBEDDING_MODEL, ) @@ -60,7 +67,7 @@ rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()), - max_token_size=os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192), + max_token_size=EMBEDDING_MAX_TOKEN_SIZE, func=embedding_func), )