Fix a bug where args_hash was computed only when the regular cache was used, so it was never computed when the embedding cache was used

Author: magicyuan876
Date: 2024-12-06 10:28:35 +08:00
parent 6540d11096
commit 6c29a37f20

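The bug in miniature: each *_complete_if_cache helper computed args_hash only inside the regular-cache branch, so with the embedding cache enabled the later cache-write step referenced a variable that was never assigned. Below is a minimal, self-contained sketch of that control flow; compute_args_hash here is a stand-in for lightrag.utils.compute_args_hash, and a plain dict stands in for the BaseKVStorage-backed hashing_kv.

import asyncio
import hashlib


def compute_args_hash(*args) -> str:
    # Stand-in for lightrag.utils.compute_args_hash: hash the call arguments.
    return hashlib.md5(str(args).encode()).hexdigest()


async def complete_if_cache_sketch(
    model, messages, hashing_kv=None, embedding_cache_enabled=False
):
    if hashing_kv is not None:
        # The fix: compute args_hash whenever caching is active at all,
        # not only in the regular-cache branch below.
        args_hash = compute_args_hash(model, messages)
        if embedding_cache_enabled:
            pass  # similarity-based lookup would go here
        else:
            # Before the fix, args_hash was computed only in this branch.
            if args_hash in hashing_kv:
                return hashing_kv[args_hash]

    response = f"LLM answer to: {messages[-1]['content']}"
    if hashing_kv is not None:
        # With the embedding cache enabled, this write raised
        # UnboundLocalError before the fix: args_hash was never assigned.
        hashing_kv[args_hash] = response
    return response


cache: dict = {}
print(asyncio.run(complete_if_cache_sketch(
    "some-model", [{"role": "user", "content": "hi"}],
    hashing_kv=cache, embedding_cache_enabled=True,
)))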

@@ -1,12 +1,16 @@
-import os
+import base64
 import copy
-from functools import lru_cache
 import json
+import os
+import struct
+from functools import lru_cache
+from typing import List, Dict, Callable, Any
 
 import aioboto3
 import aiohttp
 import numpy as np
 import ollama
+import torch
 from openai import (
     AsyncOpenAI,
     APIConnectionError,
@@ -14,10 +18,7 @@ from openai import (
     Timeout,
     AsyncAzureOpenAI,
 )
-
-import base64
-import struct
-
+from pydantic import BaseModel, Field
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -25,9 +26,7 @@ from tenacity import (
     retry_if_exception_type,
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -70,7 +69,7 @@ async def openai_complete_if_cache(
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -167,7 +166,7 @@ async def azure_openai_complete_if_cache(
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -281,7 +280,7 @@ async def bedrock_complete_if_cache(
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -378,7 +377,7 @@ async def hf_model_if_cache(
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -496,7 +495,7 @@ async def ollama_model_if_cache(
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -657,7 +656,7 @@ async def lmdeploy_model_if_cache(
     if hashing_kv is not None:
+        # Calculate args_hash only when using cache
+        args_hash = compute_args_hash(model, messages)
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
@@ -867,7 +866,8 @@ async def openai_embedding(
 )
 async def nvidia_openai_embedding(
     texts: list[str],
-    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",  # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
+    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
+    # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
     base_url: str = "https://integrate.api.nvidia.com/v1",
     api_key: str = None,
     input_type: str = "passage",  # query for retrieval, passage for embedding
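The last hunk only moves a long inline comment in the nvidia_openai_embedding signature onto its own line. A usage sketch inferred from that signature; the API key is a placeholder, and the result is assumed to be one embedding per input text, in line with the module's other embedding helpers.

import asyncio

from lightrag.llm import nvidia_openai_embedding

embeddings = asyncio.run(nvidia_openai_embedding(
    ["An example passage to embed."],
    model="nvidia/llama-3.2-nv-embedqa-1b-v1",   # default shown above
    base_url="https://integrate.api.nvidia.com/v1",
    api_key="YOUR_NVIDIA_API_KEY",               # placeholder credential
    input_type="passage",                        # "query" for retrieval-side inputs
))
print(len(embeddings))  # one vector per input text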