diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
new file mode 100644
index 00000000..75ecc118
--- /dev/null
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -0,0 +1,69 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await openai_complete_if_cache(
+        "solar-mini",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        api_key=os.getenv("UPSTAGE_API_KEY"),
+        base_url="https://api.upstage.ai/v1/solar",
+        **kwargs
+    )
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await openai_embedding(
+        texts,
+        model="solar-embedding-1-large-query",
+        api_key=os.getenv("UPSTAGE_API_KEY"),
+        base_url="https://api.upstage.ai/v1/solar"
+    )
+
+# function test
+async def test_funcs():
+    result = await llm_model_func("How are you?")
+    print("llm_model_func: ", result)
+
+    result = await embedding_func(["How are you?"])
+    print("embedding_func: ", result)
+
+asyncio.run(test_funcs())
+
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=4096,
+        max_token_size=8192,
+        func=embedding_func
+    )
+)
+
+
+with open("./book.txt") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
diff --git a/lightrag/llm.py b/lightrag/llm.py
index bcb7e495..d2ca5344 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -19,9 +19,12 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def openai_complete_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], **kwargs
+    model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
 ) -> str:
-    openai_async_client = AsyncOpenAI()
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
+
+    openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -133,10 +136,13 @@ async def hf_model_complete(
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
-async def openai_embedding(texts: list[str]) -> np.ndarray:
-    openai_async_client = AsyncOpenAI()
+async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
+
+    openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     response = await openai_async_client.embeddings.create(
-        model="text-embedding-3-small", input=texts, encoding_format="float"
+        model=model, input=texts, encoding_format="float"
     )
     return np.array([dp.embedding for dp in response.data])