Merge branch 'main' into main

Author: zrguo
Date: 2024-12-06 11:38:27 +08:00
Committed by: GitHub
3 changed files with 55 additions and 19 deletions

View File

@@ -596,7 +596,11 @@ if __name__ == "__main__":
 | **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
 | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
 | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
-| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
+| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters: `enabled` (bool) enables/disables the caching functionality; when enabled, questions and answers are cached. `similarity_threshold` (float, 0-1) is the similarity threshold; when a new question's similarity with a cached question exceeds it, the cached answer is returned directly without calling the LLM. | `{"enabled": False, "similarity_threshold": 0.95}` |

 ## API Server Implementation
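For reference, a minimal usage sketch (not part of this commit) of how the `embedding_cache_config` option described in the table above might be passed when constructing `LightRAG`; the working directory is a placeholder and the remaining constructor defaults are assumed.

```python
# Hypothetical usage sketch: enable the question-answer cache from the table.
# Field names follow the README table; the working directory is a placeholder.
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",
    embedding_cache_config={
        "enabled": True,                # cache questions and answers
        "similarity_threshold": 0.95,   # reuse a cached answer above this similarity
    },
)
```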

View File

@@ -143,7 +143,7 @@ class OracleDB:
                     data = None
             return data

-    async def execute(self, sql: str, data: list | dict = None):
+    async def execute(self, sql: str, data: Union[list, dict] = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
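The only change in this file swaps the PEP 604 union annotation (`list | dict`) for `typing.Union`, which is also valid on Python versions before 3.10. A small illustrative sketch of the two spellings (the stub below is illustrative, not the project's code):

```python
from typing import Union

# What the diff switches to: accepted as a runtime annotation on Python 3.7+.
async def execute_compat(sql: str, data: Union[list, dict] = None):
    ...

# The replaced spelling, `data: list | dict = None`, is only accepted as a
# runtime annotation on Python 3.10+, which is presumably why it was changed.
```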

View File

@@ -1,12 +1,16 @@
-import os
+import base64
 import copy
-from functools import lru_cache
 import json
+import os
+import struct
+from functools import lru_cache
+from typing import List, Dict, Callable, Any, Union
 import aioboto3
 import aiohttp
 import numpy as np
 import ollama
+import torch
 from openai import (
     AsyncOpenAI,
     APIConnectionError,
@@ -14,10 +18,7 @@ from openai import (
     Timeout,
     AsyncAzureOpenAI,
 )
-import base64
-import struct
+from pydantic import BaseModel, Field
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -25,9 +26,7 @@ from tenacity import (
     retry_if_exception_type,
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any, Union
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -73,7 +72,11 @@ async def openai_complete_if_cache(
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages) messages.extend(history_messages)
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
if hashing_kv is not None: if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration # Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get( embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -93,7 +96,6 @@ async def openai_complete_if_cache(
                 return best_cached_response
         else:
             # Use regular cache
-            args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]
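Across the hunks above (and the analogous ones below for the Azure, Bedrock, Hugging Face, Ollama, and LMDeploy wrappers), the commit moves the `compute_args_hash` call so it runs once whenever a cache store is present, before choosing between the similarity-based question-answer cache and the exact-hash cache. A self-contained toy sketch of that control flow follows; the dict-backed caches and the string-similarity check are stand-ins, not the library's actual storage or embedding comparison.

```python
# Toy sketch of the caching pattern: hash once up front, then consult either a
# similarity cache or an exact-hash cache. All names here are illustrative.
import asyncio
import hashlib
from difflib import SequenceMatcher

EXACT_CACHE = {}   # args_hash -> answer
QA_CACHE = []      # (question, answer) pairs for the similarity path


def compute_args_hash(*args) -> str:
    return hashlib.md5(str(args).encode()).hexdigest()


async def cached_llm_call(model, prompt, *, cache_enabled=True,
                          similarity_threshold=0.95):
    # Mirrors the diff: the hash is computed once, as soon as caching applies.
    args_hash = compute_args_hash(model, prompt)

    if cache_enabled:
        # Similarity cache: reuse an answer when a stored question is close enough.
        for question, answer in QA_CACHE:
            if SequenceMatcher(None, question, prompt).ratio() >= similarity_threshold:
                return answer
    else:
        # Regular cache: exact lookup by the hash of (model, prompt).
        if args_hash in EXACT_CACHE:
            return EXACT_CACHE[args_hash]

    answer = f"LLM({model}) answer to: {prompt}"  # stand-in for the real API call
    EXACT_CACHE[args_hash] = answer
    QA_CACHE.append((prompt, answer))
    return answer


if __name__ == "__main__":
    print(asyncio.run(cached_llm_call("demo-model", "What is LightRAG?")))
    print(asyncio.run(cached_llm_call("demo-model", "What is LightRAG ?")))  # similarity hit
```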
@@ -166,7 +168,11 @@ async def azure_openai_complete_if_cache(
     messages.extend(history_messages)
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -185,7 +191,7 @@ async def azure_openai_complete_if_cache(
             if best_cached_response is not None:
                 return best_cached_response
         else:
-            args_hash = compute_args_hash(model, messages)
+            # Use regular cache
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]
@@ -278,6 +284,9 @@ async def bedrock_complete_if_cache(
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -297,7 +306,6 @@ async def bedrock_complete_if_cache(
                 return best_cached_response
         else:
             # Use regular cache
-            args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]
@@ -350,6 +358,11 @@ def initialize_hf_model(model_name):
     return hf_model, hf_tokenizer


+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
 async def hf_model_if_cache(
     model,
     prompt,
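This hunk (and the matching ones before `ollama_model_if_cache` and `lmdeploy_model_if_cache` below) adds a tenacity retry decorator: up to three attempts with exponential backoff, retried only for rate-limit, connection, and timeout errors. A self-contained sketch of the same pattern, using a local stand-in exception since the real ones come from the `openai` package (the demo pauses a few seconds between attempts because of the 4-10 s backoff window):

```python
# Stand-alone sketch of the retry pattern added in this commit. The decorator
# arguments mirror the diff; TransientError stands in for RateLimitError,
# APIConnectionError, and Timeout from the openai package.
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)


class TransientError(Exception):
    """Placeholder for the transient API errors listed in the diff."""


ATTEMPTS = {"count": 0}


@retry(
    stop=stop_after_attempt(3),                          # give up after 3 tries
    wait=wait_exponential(multiplier=1, min=4, max=10),  # back off 4-10 s between tries
    retry=retry_if_exception_type(TransientError),       # only retry these errors
)
def flaky_call() -> str:
    ATTEMPTS["count"] += 1
    if ATTEMPTS["count"] < 3:
        raise TransientError("temporary failure")
    return f"succeeded on attempt {ATTEMPTS['count']}"


if __name__ == "__main__":
    print(flaky_call())
```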
@@ -367,6 +380,9 @@ async def hf_model_if_cache(
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
if hashing_kv is not None: if hashing_kv is not None:
# Calculate args_hash only when using cache
args_hash = compute_args_hash(model, messages)
# Get embedding cache configuration # Get embedding cache configuration
embedding_cache_config = hashing_kv.global_config.get( embedding_cache_config = hashing_kv.global_config.get(
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -386,7 +402,6 @@ async def hf_model_if_cache(
                 return best_cached_response
         else:
             # Use regular cache
-            args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]
@@ -455,6 +470,11 @@ async def hf_model_if_cache(
     return response_text


+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
 async def ollama_model_if_cache(
     model,
     prompt,
@@ -476,7 +496,11 @@ async def ollama_model_if_cache(
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -496,7 +520,6 @@ async def ollama_model_if_cache(
                 return best_cached_response
         else:
             # Use regular cache
-            args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]
@@ -561,6 +584,11 @@ def initialize_lmdeploy_pipeline(
     return lmdeploy_pipe


+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
 async def lmdeploy_model_if_cache(
     model,
     prompt,
@@ -639,7 +667,11 @@ async def lmdeploy_model_if_cache(
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages.extend(history_messages)
     messages.append({"role": "user", "content": prompt})
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -659,7 +691,6 @@ async def lmdeploy_model_if_cache(
                 return best_cached_response
         else:
             # Use regular cache
-            args_hash = compute_args_hash(model, messages)
             if_cache_return = await hashing_kv.get_by_id(args_hash)
             if if_cache_return is not None:
                 return if_cache_return["return"]
@@ -850,7 +881,8 @@ async def openai_embedding(
 )
 async def nvidia_openai_embedding(
     texts: list[str],
-    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",  # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
+    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
+    # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
     base_url: str = "https://integrate.api.nvidia.com/v1",
     api_key: str = None,
     input_type: str = "passage",  # query for retrieval, passage for embedding