From 6c29a37f2063283d0748d2139ddccace23a0700b Mon Sep 17 00:00:00 2001
From: magicyuan876 <317617749@qq.com>
Date: Fri, 6 Dec 2024 10:28:35 +0800
Subject: [PATCH] Fix bug where args_hash was only calculated when the regular
 cache was used, so it was never calculated for the embedding cache
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lightrag/llm.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index fdfb70a8..97c903d2 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,12 +1,16 @@
-import os
+import base64
 import copy
-from functools import lru_cache
 import json
+import os
+import struct
+from functools import lru_cache
+from typing import List, Dict, Callable, Any
+
 import aioboto3
 import aiohttp
 import numpy as np
 import ollama
-
+import torch
 from openai import (
     AsyncOpenAI,
     APIConnectionError,
@@ -14,10 +18,7 @@ from openai import (
     Timeout,
     AsyncAzureOpenAI,
 )
-
-import base64
-import struct
-
+from pydantic import BaseModel, Field
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -25,9 +26,7 @@ from tenacity import (
     retry_if_exception_type,
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-from pydantic import BaseModel, Field
-from typing import List, Dict, Callable, Any
+
 from .base import BaseKVStorage
 from .utils import (
     compute_args_hash,
@@ -70,7 +69,7 @@ async def openai_complete_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
         )
@@ -167,7 +166,7 @@ async def azure_openai_complete_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
         )
@@ -281,7 +280,7 @@ async def bedrock_complete_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
         )
@@ -378,7 +377,7 @@ async def hf_model_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
         )
@@ -496,7 +495,7 @@ async def ollama_model_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
         )
@@ -657,7 +656,7 @@ async def lmdeploy_model_if_cache(
     if hashing_kv is not None:
         # Calculate args_hash only when using cache
         args_hash = compute_args_hash(model, messages)
-
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
         )
@@ -867,7 +866,8 @@ async def openai_embedding(
 )
 async def nvidia_openai_embedding(
     texts: list[str],
-    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",  # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
+    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
+    # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
     base_url: str = "https://integrate.api.nvidia.com/v1",
     api_key: str = None,
     input_type: str = "passage",  # query for retrieval, passage for embedding
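
Note: below is a minimal, hypothetical sketch (not code from this repository)
of the control flow the fix establishes: args_hash is computed whenever a
cache store is attached, before either cache branch runs, so the
embedding-cache path no longer depends on a hash that was previously computed
only on the regular-cache path. Only compute_args_hash, hashing_kv, and the
embedding_cache_config default are taken from the diff above; the md5
stand-in and the get_by_id()/"return" storage shape are assumptions.

from hashlib import md5


def compute_args_hash(*args) -> str:
    # Stand-in for lightrag.utils.compute_args_hash (assumed behavior):
    # turn the serialized call arguments into a stable cache key.
    return md5(str(args).encode()).hexdigest()


async def complete_if_cache_sketch(model, messages, hashing_kv=None):
    args_hash = None
    if hashing_kv is not None:
        # The fix: hash the arguments up front for every cached call, so
        # the embedding-cache branch below never sees an unset args_hash.
        args_hash = compute_args_hash(model, messages)

        embedding_cache_config = hashing_kv.global_config.get(
            "embedding_cache_config",
            {"enabled": False, "similarity_threshold": 0.95},
        )
        if embedding_cache_config["enabled"]:
            # Similarity-based lookup; a fresh response is later stored
            # under args_hash (details elided in this sketch).
            pass
        else:
            # Exact-match lookup keyed directly by args_hash
            # (hypothetical storage shape).
            cached = await hashing_kv.get_by_id(args_hash)
            if cached is not None:
                return cached["return"]
    # ...fall through to the real model call...

Computing the hash unconditionally whenever hashing_kv is present costs one
extra hash over the serialized arguments and removes the unbound-variable
failure mode in the embedding-cache branch.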