diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 25199888..9c34a607 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -5,7 +5,7 @@
 from datetime import datetime
 from functools import partial
 from typing import Type, cast
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model,hf_embedding
+from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -77,12 +77,13 @@ class LightRAG:
     )
 
     # text embedding
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)#openai_embedding
+    # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = hf_model#gpt_4o_mini_complete
+    llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
     llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
diff --git a/lightrag/llm.py b/lightrag/llm.py
index ac1471c1..5fb27b04 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -115,10 +115,10 @@ async def gpt_4o_mini_complete(
 
 
 
-async def hf_model(
+async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    input_string = kwargs.get('model_name', 'google/gemma-2-2b-it')
+    input_string = kwargs['hashing_kv'].global_config['llm_model_name']
     return await hf_model_if_cache(
         input_string,
         prompt,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 21b914f9..a8213a37 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -959,7 +959,6 @@ async def naive_query(
     global_config: dict,
 ):
     use_model_func = global_config["llm_model_func"]
-    use_model_name = global_config['llm_model_name']
     results = await chunks_vdb.query(query, top_k=query_param.top_k)
     if not len(results):
         return PROMPTS["fail_response"]
@@ -982,7 +981,6 @@ async def naive_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
-        model_name = use_model_name
     )
 
     if len(response)>len(sys_prompt):
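
The net effect of this patch: hf_model is renamed to hf_model_complete, the Hugging Face model name is now resolved at call time from hashing_kv.global_config['llm_model_name'] instead of a model_name kwarg, naive_query no longer threads the model name through explicitly, and the shipped defaults revert to gpt_4o_mini_complete / openai_embedding. Below is a minimal usage sketch of the new HF path, not part of the diff; the working directory and model name are illustrative assumptions, and the embedding setup is omitted (with these defaults it would still have to be overridden separately for a fully local run).

# Sketch only: select the Hugging Face backend by passing hf_model_complete
# together with llm_model_name; hf_model_complete reads the name back via
# hashing_kv.global_config at call time.
from lightrag import LightRAG, QueryParam
from lightrag.llm import hf_model_complete

rag = LightRAG(
    working_dir="./rag_workdir",                        # assumed path
    llm_model_func=hf_model_complete,                   # wrapper renamed in this patch
    llm_model_name="meta-llama/Llama-3.2-1B-Instruct",  # surfaced to hf_model_complete via global_config
)

rag.insert("Some document text to index.")
print(rag.query("What does the document say?", param=QueryParam(mode="naive")))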