update hf_model_complete

commit e5876fb225
parent 741953c34b
Author: TianyuFan0504
Date:   2024-10-14 20:33:46 +08:00
3 changed files with 6 additions and 7 deletions

lightrag/lightrag.py

@@ -5,7 +5,7 @@ from datetime import datetime
 from functools import partial
 from typing import Type, cast
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model,hf_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -77,12 +77,13 @@ class LightRAG:
     )
     # text embedding
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)#openai_embedding
+    # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
     # LLM
-    llm_model_func: callable = hf_model#gpt_4o_mini_complete
+    llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
     llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
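Note: after this commit the defaults point back at the OpenAI path (openai_embedding plus gpt_4o_mini_complete), and the Hugging Face path is selected explicitly by the caller. A minimal usage sketch, assuming the package-level exports below; the working directory is a hypothetical placeholder:

from lightrag import LightRAG
from lightrag.llm import hf_model_complete, hf_embedding

rag = LightRAG(
    working_dir="./rag_storage",  # hypothetical path, adjust as needed
    llm_model_func=hf_model_complete,  # renamed from hf_model in this commit
    llm_model_name="meta-llama/Llama-3.2-1B-Instruct",
    embedding_func=hf_embedding,
)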

lightrag/llm.py

@@ -115,10 +115,10 @@ async def gpt_4o_mini_complete(
-async def hf_model(
+async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    input_string = kwargs.get('model_name', 'google/gemma-2-2b-it')
+    input_string = kwargs['hashing_kv'].global_config['llm_model_name']
     return await hf_model_if_cache(
         input_string,
         prompt,
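Note: the model name is now read from the global config carried by hashing_kv rather than from a per-call kwarg with a hard-coded fallback, so every call site resolves the same llm_model_name. A self-contained sketch of the new lookup; FakeKV is a hypothetical stand-in for the storage object LightRAG binds in as hashing_kv:

from dataclasses import dataclass, field

@dataclass
class FakeKV:
    # stand-in for the cache storage; the real object carries the
    # LightRAG configuration as its global_config attribute
    global_config: dict = field(default_factory=dict)

def resolve_model_name(**kwargs) -> str:
    # mirrors the new lookup above: read from the bound hashing_kv instead
    # of kwargs.get('model_name', 'google/gemma-2-2b-it')
    return kwargs["hashing_kv"].global_config["llm_model_name"]

kv = FakeKV(global_config={"llm_model_name": "meta-llama/Llama-3.2-1B-Instruct"})
assert resolve_model_name(hashing_kv=kv) == "meta-llama/Llama-3.2-1B-Instruct"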

lightrag/operate.py

@@ -959,7 +959,6 @@ async def naive_query(
     global_config: dict,
 ):
     use_model_func = global_config["llm_model_func"]
-    use_model_name = global_config['llm_model_name']
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
@@ -982,7 +981,6 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
-        model_name = use_model_name
     )
     if len(response)>len(sys_prompt):
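Note: with the name resolved inside hf_model_complete, callers such as naive_query no longer thread it through. Paraphrasing the two hunks above, the call site reduces from

response = await use_model_func(
    query,
    system_prompt=sys_prompt,
    model_name=use_model_name,
)

to

response = await use_model_func(
    query,
    system_prompt=sys_prompt,
)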