Merge branch 'main' into main
This commit is contained in:
@@ -596,7 +596,11 @@ if __name__ == "__main__":
|
||||
| **enable\_llm\_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
|
||||
| **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese"}`: sets example limit and output language | `example_number: all examples, language: English` |
|
||||
| **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` |
|
||||
| **embedding\_cache\_config** | `dict` | Configuration for embedding cache. Includes `enabled` (bool) to toggle cache and `similarity_threshold` (float) for cache retrieval | `{"enabled": False, "similarity_threshold": 0.95}` |
|
||||
| **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains two parameters:
|
||||
- `enabled`: Boolean value to enable/disable caching functionality. When enabled, questions and answers will be cached.
|
||||
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
|
||||
|
||||
Default: `{"enabled": False, "similarity_threshold": 0.95}` | `{"enabled": False, "similarity_threshold": 0.95}` |
|
||||
|
||||
## API Server Implementation
|
||||
|
||||
|
@@ -143,7 +143,7 @@ class OracleDB:
|
||||
data = None
|
||||
return data
|
||||
|
||||
async def execute(self, sql: str, data: list | dict = None):
|
||||
async def execute(self, sql: str, data: Union[list, dict] = None):
|
||||
# logger.info("go into OracleDB execute method")
|
||||
try:
|
||||
async with self.pool.acquire() as connection:
|
||||
|
@@ -1,12 +1,16 @@
|
||||
import os
|
||||
import base64
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from functools import lru_cache
|
||||
from typing import List, Dict, Callable, Any, Union
|
||||
|
||||
import aioboto3
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import ollama
|
||||
|
||||
import torch
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
@@ -14,10 +18,7 @@ from openai import (
|
||||
Timeout,
|
||||
AsyncAzureOpenAI,
|
||||
)
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
@@ -25,9 +26,7 @@ from tenacity import (
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Callable, Any, Union
|
||||
|
||||
from .base import BaseKVStorage
|
||||
from .utils import (
|
||||
compute_args_hash,
|
||||
@@ -73,7 +72,11 @@ async def openai_complete_if_cache(
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if hashing_kv is not None:
|
||||
# Calculate args_hash only when using cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
@@ -93,7 +96,6 @@ async def openai_complete_if_cache(
|
||||
return best_cached_response
|
||||
else:
|
||||
# Use regular cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
@@ -166,7 +168,11 @@ async def azure_openai_complete_if_cache(
|
||||
messages.extend(history_messages)
|
||||
if prompt is not None:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if hashing_kv is not None:
|
||||
# Calculate args_hash only when using cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
@@ -185,7 +191,7 @@ async def azure_openai_complete_if_cache(
|
||||
if best_cached_response is not None:
|
||||
return best_cached_response
|
||||
else:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
# Use regular cache
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
@@ -278,6 +284,9 @@ async def bedrock_complete_if_cache(
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
if hashing_kv is not None:
|
||||
# Calculate args_hash only when using cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
@@ -297,7 +306,6 @@ async def bedrock_complete_if_cache(
|
||||
return best_cached_response
|
||||
else:
|
||||
# Use regular cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
@@ -350,6 +358,11 @@ def initialize_hf_model(model_name):
|
||||
return hf_model, hf_tokenizer
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def hf_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
@@ -367,6 +380,9 @@ async def hf_model_if_cache(
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if hashing_kv is not None:
|
||||
# Calculate args_hash only when using cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
@@ -386,7 +402,6 @@ async def hf_model_if_cache(
|
||||
return best_cached_response
|
||||
else:
|
||||
# Use regular cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
@@ -455,6 +470,11 @@ async def hf_model_if_cache(
|
||||
return response_text
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def ollama_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
@@ -476,7 +496,11 @@ async def ollama_model_if_cache(
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if hashing_kv is not None:
|
||||
# Calculate args_hash only when using cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
@@ -496,7 +520,6 @@ async def ollama_model_if_cache(
|
||||
return best_cached_response
|
||||
else:
|
||||
# Use regular cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
@@ -561,6 +584,11 @@ def initialize_lmdeploy_pipeline(
|
||||
return lmdeploy_pipe
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def lmdeploy_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
@@ -639,7 +667,11 @@ async def lmdeploy_model_if_cache(
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if hashing_kv is not None:
|
||||
# Calculate args_hash only when using cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
|
||||
# Get embedding cache configuration
|
||||
embedding_cache_config = hashing_kv.global_config.get(
|
||||
"embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
|
||||
@@ -659,7 +691,6 @@ async def lmdeploy_model_if_cache(
|
||||
return best_cached_response
|
||||
else:
|
||||
# Use regular cache
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
@@ -850,7 +881,8 @@ async def openai_embedding(
|
||||
)
|
||||
async def nvidia_openai_embedding(
|
||||
texts: list[str],
|
||||
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
||||
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
||||
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
api_key: str = None,
|
||||
input_type: str = "passage", # query for retrieval, passage for embedding
|
||||
|
Reference in New Issue
Block a user