diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py
new file mode 100644
index 00000000..9345ada5
--- /dev/null
+++ b/examples/lightrag_openai_compatible_stream_demo.py
@@ -0,0 +1,55 @@
+import os
+import inspect
+from lightrag import LightRAG
+from lightrag.llm import openai_complete, openai_embedding
+from lightrag.utils import EmbeddingFunc
+from lightrag.lightrag import always_get_an_event_loop
+from lightrag import QueryParam
+
+# WorkingDir
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+print(f"WorkingDir: {WORKING_DIR}")
+
+api_key = "empty"
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=openai_complete,
+    llm_model_name="qwen2.5-14b-instruct@4bit",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: openai_embedding(
+            texts=texts,
+            model="text-embedding-bge-m3",
+            base_url="http://127.0.0.1:1234/v1",
+            api_key=api_key,
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        if chunk:
+            print(chunk, end="", flush=True)
+
+
+loop = always_get_an_event_loop()
+if inspect.isasyncgen(resp):
+    loop.run_until_complete(print_stream(resp))
+else:
+    print(resp)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index d02d5350..05f4bf00 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -91,26 +91,40 @@ async def openai_complete_if_cache(
 
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=content,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
+    if hasattr(response, "__aiter__"):
 
-    return content
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=content,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
+
+        return content
 
 
 @retry(
@@ -431,7 +445,7 @@ async def ollama_model_if_cache(
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""
 
         async def inner():
             async for chunk in response:
@@ -613,6 +627,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]
 
 
+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -1089,12 +1119,14 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         mode_cache[cache_data.args_hash] = {
             "return": cache_data.content,
             "model": cache_data.model,
-            "embedding": cache_data.quantized.tobytes().hex()
-            if cache_data.quantized is not None
-            else None,
-            "embedding_shape": cache_data.quantized.shape
-            if cache_data.quantized is not None
-            else None,
+            "embedding": (
+                cache_data.quantized.tobytes().hex()
+                if cache_data.quantized is not None
+                else None
+            ),
+            "embedding_shape": (
+                cache_data.quantized.shape if cache_data.quantized is not None else None
+            ),
             "embedding_min": cache_data.min_val,
             "embedding_max": cache_data.max_val,
             "original_prompt": cache_data.prompt,