Merge pull request #417 from partoneplay/main
Add support for OpenAI-compatible streaming output and delete unreachable code
examples/lightrag_openai_compatible_stream_demo.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import os
import inspect
from lightrag import LightRAG
from lightrag.llm import openai_complete, openai_embedding
from lightrag.utils import EmbeddingFunc
from lightrag.lightrag import always_get_an_event_loop
from lightrag import QueryParam

# WorkingDir
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)
print(f"WorkingDir: {WORKING_DIR}")

api_key = "empty"
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=openai_complete,
    llm_model_name="qwen2.5-14b-instruct@4bit",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: openai_embedding(
            texts=texts,
            model="text-embedding-bge-m3",
            base_url="http://127.0.0.1:1234/v1",
            api_key=api_key,
        ),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)


async def print_stream(stream):
    async for chunk in stream:
        if chunk:
            print(chunk, end="", flush=True)


loop = always_get_an_event_loop()
if inspect.isasyncgen(resp):
    loop.run_until_complete(print_stream(resp))
else:
    print(resp)
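The demo assumes an OpenAI-compatible server (for example a local LM Studio or vLLM instance) is already serving the configured chat and embedding models at http://127.0.0.1:1234/v1, and that a book.txt file is present in the current directory; with those in place, it inserts the document and prints the hybrid-mode answer chunk by chunk when the response streams, or all at once otherwise.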
@@ -76,10 +76,23 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-    return content
+
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+        return content
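The core of this change is the hasattr(response, "__aiter__") branch: when the OpenAI-compatible client is asked to stream, the response is an async iterable of chunks, so the wrapper now returns an async generator of text deltas instead of a single string. Below is a minimal, self-contained sketch of that pattern (toy code, not LightRAG's implementation; fake_stream and complete are hypothetical stand-ins):

import asyncio
import inspect


async def fake_stream():
    # Stand-in for a streaming chat.completions response: an async iterable of deltas.
    for piece in ["Hello, ", "world!"]:
        yield piece


async def complete(stream: bool):
    # Mirrors the branch above: streaming responses expose __aiter__,
    # non-streaming calls return a plain string.
    response = fake_stream() if stream else "Hello, world!"
    if hasattr(response, "__aiter__"):

        async def inner():
            async for chunk in response:
                yield chunk

        return inner()
    return response


async def main():
    resp = await complete(stream=True)
    if inspect.isasyncgen(resp):  # same check the demo script uses
        async for chunk in resp:
            print(chunk, end="", flush=True)
    else:
        print(resp)


asyncio.run(main())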
@@ -447,6 +460,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]
 
 
+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
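This new openai_complete function is what the example script above passes as llm_model_func: instead of hard-coding a model, it reads llm_model_name from the hashing_kv global config (the demo sets it to "qwen2.5-14b-instruct@4bit") and forwards the remaining arguments to openai_complete_if_cache, which may now return either a string or an async iterator of text chunks.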
@@ -488,7 +488,7 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None:
+    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
         return
 
     mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
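The added hasattr(cache_data.content, "__aiter__") guard keeps streaming results out of the cache: when the completion function returns an async generator, the content is a single-use iterator rather than a finished string, so save_to_cache now skips it instead of trying to persist it. A small hypothetical check (not LightRAG code) shows what the guard distinguishes:

async def fake_chunks():
    # Hypothetical streamed answer: an async generator, not a finished string.
    yield "partial "
    yield "answer"

print(hasattr(fake_chunks(), "__aiter__"))   # True  -> save_to_cache returns early
print(hasattr("final answer", "__aiter__"))  # False -> plain strings are cached as before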