update hf_model_complete
@@ -5,7 +5,7 @@ from datetime import datetime
 from functools import partial
 from typing import Type, cast
 
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model,hf_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -77,12 +77,13 @@ class LightRAG:
     )
 
     # text embedding
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)#openai_embedding
+    # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
    embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = hf_model#gpt_4o_mini_complete
+    llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
     llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
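
Note: this hunk flips the LightRAG defaults to the OpenAI path and leaves the Hugging Face variants as commented alternatives. A minimal sketch of keeping the HF path by overriding the defaults at construction time instead of editing the dataclass; the constructor-style instantiation, the import paths, and the working_dir field are assumptions not shown in this diff, while hf_model_complete, hf_embedding, and llm_model_name come from the hunks themselves:

from lightrag import LightRAG                          # assumed import path
from lightrag.llm import hf_model_complete, hf_embedding

rag = LightRAG(
    working_dir="./rag_storage",                       # assumed field, not in this diff
    llm_model_func=hf_model_complete,                  # instead of the new gpt_4o_mini_complete default
    llm_model_name="meta-llama/Llama-3.2-1B-Instruct",
    embedding_func=hf_embedding,                       # instead of the new openai_embedding default
)
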
@@ -115,10 +115,10 @@ async def gpt_4o_mini_complete(
 
 
 
-async def hf_model(
+async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    input_string = kwargs.get('model_name', 'google/gemma-2-2b-it')
+    input_string = kwargs['hashing_kv'].global_config['llm_model_name']
     return await hf_model_if_cache(
         input_string,
         prompt,
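
The rename from hf_model to hf_model_complete also changes how the model name is resolved: instead of a model_name kwarg with a hard-coded default, the function now reads llm_model_name from the global_config carried by the hashing_kv handle. A rough sketch of the new calling convention, using a stand-in object (the SimpleNamespace here is purely illustrative; LightRAG injects its own storage object):

from types import SimpleNamespace

# Stand-in for the storage handle LightRAG passes in; only the attribute
# that hf_model_complete actually reads is modeled here.
hashing_kv = SimpleNamespace(
    global_config={"llm_model_name": "meta-llama/Llama-3.2-1B-Instruct"}
)

# Old convention:  await hf_model(prompt, model_name="google/gemma-2-2b-it")
# New convention:  await hf_model_complete(prompt, hashing_kv=hashing_kv)
model_name = hashing_kv.global_config["llm_model_name"]  # exactly what the new line reads
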
@@ -959,7 +959,6 @@ async def naive_query(
     global_config: dict,
 ):
     use_model_func = global_config["llm_model_func"]
-    use_model_name = global_config['llm_model_name']
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
@@ -982,7 +981,6 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
-        model_name = use_model_name
     )
 
     if len(response)>len(sys_prompt):
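
With the name resolved inside hf_model_complete, naive_query no longer threads use_model_name through every call, which is why both lines are dropped in the two hunks above. The functools.partial import in the first hunk suggests the storage handle is bound to the model function once elsewhere in the package; the sketch below only illustrates that presumed wiring, and its names are assumptions, not code from this commit:

from functools import partial

def bind_llm(llm_model_func, llm_response_cache):
    # Bind the shared storage handle once, so call sites such as naive_query
    # can stay as plain use_model_func(query, system_prompt=sys_prompt).
    return partial(llm_model_func, hashing_kv=llm_response_cache)
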