Add support for OpenAI-compatible streaming output

partoneplay
2024-12-07 14:41:09 +08:00
parent 50a17bb4f9
commit a8e09ba6c5
2 changed files with 112 additions and 25 deletions


@@ -91,26 +91,40 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=content,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
-    return content
+
+    if hasattr(response, "__aiter__"):
+
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=content,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
+        return content
 
 
 @retry(
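With this hunk, openai_complete_if_cache returns an async iterator of text deltas whenever the OpenAI-compatible client streams (detected by duck typing on __aiter__) and a plain string otherwise; note that the streaming branch never reaches save_to_cache. A minimal consumer sketch, assuming only the dual return type above (the helper name collect_response is illustrative, not part of this commit):

from typing import AsyncIterator, Union

async def collect_response(result: Union[str, AsyncIterator[str]]) -> str:
    # Streaming branch: inner() is an async generator, so it exposes
    # __aiter__ -- the same check the patch applies to the client response.
    if hasattr(result, "__aiter__"):
        parts = []
        async for delta in result:  # each piece is chunk.choices[0].delta.content
            parts.append(delta)
        return "".join(parts)
    # Non-streaming branch already returned the full (and cached) string.
    return result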
@@ -431,7 +445,7 @@ async def ollama_model_if_cache(
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""
 
         async def inner():
             async for chunk in response:
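The "cannot cache stream response" note holds because an async generator is single-pass: the full text only exists after the caller has consumed the stream, at which point there is nothing left to replay into the cache. A standalone toy illustration (not from the repo):

import asyncio

async def stream():
    for i in range(3):
        yield i

async def main():
    gen = stream()
    print([x async for x in gen])  # [0, 1, 2]
    print([x async for x in gen])  # [] -- exhausted; caching after the fact
                                   # would require buffering every chunk

asyncio.run(main())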
@@ -613,6 +627,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]
 
 
+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
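Unlike gpt_4o_complete, the new openai_complete resolves the model name from hashing_kv.global_config["llm_model_name"], so any OpenAI-compatible endpoint can be targeted by configuration. A rough call-site sketch, assuming SimpleNamespace as a stand-in for the storage object LightRAG actually passes (only the one attribute read above is modeled) and assuming stream=True is forwarded through **kwargs to the client:

import asyncio
from types import SimpleNamespace

async def demo():
    # Hypothetical stand-in: real code passes a storage instance here.
    hashing_kv = SimpleNamespace(global_config={"llm_model_name": "qwen2.5:14b"})
    result = await openai_complete(
        "Summarize the document.",
        system_prompt="You are a helpful assistant.",
        hashing_kv=hashing_kv,
        stream=True,  # forwarded to the OpenAI-compatible client
    )
    async for delta in result:  # streaming path returns an async iterator
        print(delta, end="", flush=True)

asyncio.run(demo())

One caveat visible in the hunk: because keyword_extraction is a named parameter, kwargs.pop("keyword_extraction", None) can only ever return None, so the response_format = "json" branch is unreachable as written.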
@@ -1089,12 +1119,14 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         mode_cache[cache_data.args_hash] = {
             "return": cache_data.content,
             "model": cache_data.model,
-            "embedding": cache_data.quantized.tobytes().hex()
-            if cache_data.quantized is not None
-            else None,
-            "embedding_shape": cache_data.quantized.shape
-            if cache_data.quantized is not None
-            else None,
+            "embedding": (
+                cache_data.quantized.tobytes().hex()
+                if cache_data.quantized is not None
+                else None
+            ),
+            "embedding_shape": (
+                cache_data.quantized.shape if cache_data.quantized is not None else None
+            ),
             "embedding_min": cache_data.min_val,
             "embedding_max": cache_data.max_val,
             "original_prompt": cache_data.prompt,
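Since the cache stores the quantized embedding as tobytes().hex() alongside its shape, the read path can rebuild the array with NumPy. A sketch of that inverse, assuming a dtype of np.uint8 for the quantized values (the real dtype is whatever the quantization step produced; restore_embedding is a hypothetical helper, not in this commit):

import numpy as np

def restore_embedding(entry: dict, dtype=np.uint8):
    # Inverse of quantized.tobytes().hex(): hex -> bytes -> flat array -> shape.
    if entry["embedding"] is None:
        return None
    flat = np.frombuffer(bytes.fromhex(entry["embedding"]), dtype=dtype)
    return flat.reshape(entry["embedding_shape"])

De-quantizing would then use the stored embedding_min and embedding_max to map the integers back to floats.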