From 9bab6300592d182c8be26ce7e2a675331610ee6b Mon Sep 17 00:00:00 2001 From: 90houlaoheshang <907333918@qq.com> Date: Tue, 5 Nov 2024 15:13:48 +0800 Subject: [PATCH] =?UTF-8?q?fastapi=E6=8E=A5=E6=94=B6=E7=8E=AF=E5=A2=83?= =?UTF-8?q?=E5=8F=98=E9=87=8FEMBEDDING=5FMODEL=E3=80=81LLM=5FMODEL?= =?UTF-8?q?=E3=80=81OPENAI=5FAPI=5FKEY=E3=80=81OPENAI=5FBASE=5FURL?= =?UTF-8?q?=E4=BB=A5=E8=87=AA=E5=AE=9A=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../lightrag_api_openai_compatible_demo.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py index 2cd262bb..7e1f608a 100644 --- a/examples/lightrag_api_openai_compatible_demo.py +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -21,19 +21,18 @@ print(f"WORKING_DIR: {WORKING_DIR}") if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + # LLM model function async def llm_model_func( - prompt, system_prompt=None, history_messages=[], **kwargs + prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: return await openai_complete_if_cache( - "gpt-4o-mini", + os.environ.get("LLM_MODEL", "gpt-4o-mini"), prompt, system_prompt=system_prompt, history_messages=history_messages, - api_key="YOUR_API_KEY", - base_url="YourURL/v1", **kwargs, ) @@ -44,21 +43,28 @@ async def llm_model_func( async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, - model="text-embedding-3-large", - api_key="YOUR_API_KEY", - base_url="YourURL/v1", + model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"), ) +async def get_embedding_dim(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + embedding_dim = embedding.shape[1] + print(f"{embedding_dim=}") + return embedding_dim + + # Initialize RAG instance rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=3072, max_token_size=8192, func=embedding_func - ), + embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()), + max_token_size=8192, + func=embedding_func), ) + # Data models