From 8e0877f4b0f9ea50326be5dcf8289bcdffb6442e Mon Sep 17 00:00:00 2001
From: zrguo
Date: Fri, 18 Oct 2024 15:32:58 +0800
Subject: [PATCH] Create lightrag_azure_openai_demo.py

---
 examples/lightrag_azure_openai_demo.py | 127 +++++++++++++++++++++++++
 1 file changed, 127 insertions(+)
 create mode 100644 examples/lightrag_azure_openai_demo.py

diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py
new file mode 100644
index 00000000..62282a25
--- /dev/null
+++ b/examples/lightrag_azure_openai_demo.py
@@ -0,0 +1,127 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+from dotenv import load_dotenv
+import aiohttp
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+load_dotenv()
+
+AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
+AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
+AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
+AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
+
+WORKING_DIR = "./dickens"
+
+if os.path.exists(WORKING_DIR):
+    import shutil
+
+    shutil.rmtree(WORKING_DIR)
+
+os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=None, **kwargs
+) -> str:
+    headers = {
+        "Content-Type": "application/json",
+        "api-key": AZURE_OPENAI_API_KEY,
+    }
+    endpoint = f"{AZURE_OPENAI_ENDPOINT.rstrip('/')}/openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
+
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    if history_messages:
+        messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    payload = {
+        "messages": messages,
+        "temperature": kwargs.get("temperature", 0),
+        "top_p": kwargs.get("top_p", 1),
+        "n": kwargs.get("n", 1),
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(endpoint, headers=headers, json=payload) as response:
+            if response.status != 200:
+                raise ValueError(
+                    f"Request failed with status {response.status}: {await response.text()}"
+                )
+            result = await response.json()
+            return result["choices"][0]["message"]["content"]
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    headers = {
+        "Content-Type": "application/json",
+        "api-key": AZURE_OPENAI_API_KEY,
+    }
+    endpoint = f"{AZURE_OPENAI_ENDPOINT.rstrip('/')}/openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
+
+    payload = {"input": texts}
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(endpoint, headers=headers, json=payload) as response:
+            if response.status != 200:
+                raise ValueError(
+                    f"Request failed with status {response.status}: {await response.text()}"
+                )
+            result = await response.json()
+            embeddings = [item["embedding"] for item in result["data"]]
+            return np.array(embeddings)
+
+
+async def test_funcs():
+    result = await llm_model_func("How are you?")
+    print("llm_model_func response: ", result)
+
+    result = await embedding_func(["How are you?"])
+    print("embedding_func result shape: ", result.shape)
+    print("Embedding dimension: ", result.shape[1])
+
+
+asyncio.run(test_funcs())
+
+embedding_dimension = 3072
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
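+        # NOTE: embedding_dim must match the model behind AZURE_EMBEDDING_DEPLOYMENT
+        # (3072 for text-embedding-3-large, 1536 for text-embedding-3-small).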
+        embedding_dim=embedding_dimension,
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+)
+
+with open("./book_1.txt", encoding="utf-8") as book1, open(
+    "./book_2.txt", encoding="utf-8"
+) as book2:
+    rag.insert([book1.read(), book2.read()])
+
+query_text = "What are the main themes?"
+
+print("Result (Naive):")
+print(rag.query(query_text, param=QueryParam(mode="naive")))
+
+print("\nResult (Local):")
+print(rag.query(query_text, param=QueryParam(mode="local")))
+
+print("\nResult (Global):")
+print(rag.query(query_text, param=QueryParam(mode="global")))
+
+print("\nResult (Hybrid):")
+print(rag.query(query_text, param=QueryParam(mode="hybrid")))
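
Note: the script reads its Azure credentials from a .env file via
load_dotenv(), and expects ./book_1.txt and ./book_2.txt to exist in the
directory it is run from. A minimal .env sketch (all values below are
placeholders for your own resource's deployment names, key, and API
versions) might look like:

    AZURE_OPENAI_API_VERSION=2024-02-15-preview
    AZURE_OPENAI_DEPLOYMENT=gpt-4o
    AZURE_OPENAI_API_KEY=<your-api-key>
    AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
    AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
    AZURE_EMBEDDING_API_VERSION=2023-05-15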