fastapi接收环境变量EMBEDDING_MODEL、LLM_MODEL、OPENAI_API_KEY、OPENAI_BASE_URL以自定模型

This commit is contained in:
90houlaoheshang
2024-11-05 15:13:48 +08:00
parent f60368f8db
commit 9bab630059

View File

@@ -21,19 +21,18 @@ print(f"WORKING_DIR: {WORKING_DIR}")
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
# LLM model function # LLM model function
async def llm_model_func( async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
return await openai_complete_if_cache( return await openai_complete_if_cache(
"gpt-4o-mini", os.environ.get("LLM_MODEL", "gpt-4o-mini"),
prompt, prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
history_messages=history_messages, history_messages=history_messages,
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
**kwargs, **kwargs,
) )
@@ -44,21 +43,28 @@ async def llm_model_func(
async def embedding_func(texts: list[str]) -> np.ndarray: async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding( return await openai_embedding(
texts, texts,
model="text-embedding-3-large", model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"),
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
) )
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
print(f"{embedding_dim=}")
return embedding_dim
# Initialize RAG instance # Initialize RAG instance
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=llm_model_func, llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(embedding_dim=asyncio.run(get_embedding_dim()),
embedding_dim=3072, max_token_size=8192, func=embedding_func max_token_size=8192,
), func=embedding_func),
) )
# Data models # Data models