Add support for Ollama streaming output and integrate Open-WebUI as the chat UI demo

partoneplay
2024-12-06 08:48:55 +08:00
parent 2a2756d9d1
commit 335179196a
5 changed files with 203 additions and 23 deletions
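Only one of the five changed files appears below; it adds a streaming path to the Ollama helpers, so ollama_model_if_cache and ollama_model_complete now return either a full string or an async iterator of text chunks depending on the stream flag. A minimal sketch of the underlying client behaviour that the new branch wraps (the model name and host are assumptions, not taken from this commit):

import asyncio

import ollama


async def demo() -> None:
    client = ollama.AsyncClient(host="http://localhost:11434")  # assumed host
    messages = [{"role": "user", "content": "Hello"}]

    # Without stream, chat() resolves to a single dict holding the full reply.
    resp = await client.chat(model="llama3", messages=messages)  # assumed model
    print(resp["message"]["content"])

    # With stream=True, chat() resolves to an async iterator of partial dicts;
    # the new branch in this diff wraps such an iterator and re-yields only
    # the text of each chunk.
    stream = await client.chat(model="llama3", messages=messages, stream=True)
    async for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)


asyncio.run(demo())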


@@ -27,7 +27,7 @@ from tenacity import (
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+from typing import List, Dict, Callable, Any, Union
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -37,6 +37,13 @@ from .utils import (
     get_best_cached_response,
 )

+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -454,7 +461,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -494,28 +502,39 @@ async def ollama_model_if_cache(
             return if_cache_return["return"]

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+    if stream:
+        """ cannot cache stream response """

-    result = response["message"]["content"]
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]

-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": result,
-                    "model": model,
-                    "embedding": quantized.tobytes().hex()
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_shape": quantized.shape
-                    if is_embedding_cache_enabled
-                    else None,
-                    "embedding_min": min_val if is_embedding_cache_enabled else None,
-                    "embedding_max": max_val if is_embedding_cache_enabled else None,
-                    "original_prompt": prompt,
+        return inner()
+    else:
+        result = response["message"]["content"]
+        if hashing_kv is not None:
+            await hashing_kv.upsert(
+                {
+                    args_hash: {
+                        "return": result,
+                        "model": model,
+                        "embedding": quantized.tobytes().hex()
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_shape": quantized.shape
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_min": min_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "embedding_max": max_val
+                        if is_embedding_cache_enabled
+                        else None,
+                        "original_prompt": prompt,
+                    }
                 }
-            }
-        )
-    return result
+            )
+        return result

 @lru_cache(maxsize=1)
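Since ollama_model_if_cache now returns Union[str, AsyncIterator[str]], downstream code has to branch on the result type. A small consumption sketch (the helper name and call arguments are illustrative, not part of this commit):

from collections.abc import AsyncIterator
from typing import Union


async def print_completion(result: Union[str, AsyncIterator[str]]) -> None:
    if isinstance(result, str):
        # Non-streaming path: the full text, possibly served from the cache.
        print(result)
    else:
        # Streaming path: incremental chunks; the cache is skipped upstream
        # because the iterator can only be consumed once.
        async for piece in result:
            print(piece, end="", flush=True)
        print()


# e.g. result = await ollama_model_if_cache("llama3", "Hello", stream=True)
#      await print_completion(result)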
@@ -785,7 +804,7 @@ async def hf_model_complete(
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"
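The Open-WebUI demo itself lives in the other changed files, which are not shown in this excerpt. Purely as a hypothetical illustration of why the streaming return type matters for a chat UI (the route, import path, model name, and host are all assumptions), an AsyncIterator[str] result can be forwarded as a chunked HTTP response that a frontend such as Open-WebUI renders token by token:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from lightrag.llm import ollama_model_if_cache  # import path assumed

app = FastAPI()


@app.post("/api/chat")  # hypothetical route, not taken from this commit
async def chat(payload: dict) -> StreamingResponse:
    # hashing_kv is omitted here, so nothing is cached (assumed optional).
    chunks = await ollama_model_if_cache(
        "llama3",  # assumed model name
        payload["prompt"],
        stream=True,
        host="http://localhost:11434",  # assumed Ollama host
    )
    # StreamingResponse accepts an async iterator of str and flushes each
    # chunk to the client as it is produced.
    return StreamingResponse(chunks, media_type="text/plain")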