Add HF Support
This commit is contained in:
@@ -5,7 +5,7 @@ from datetime import datetime
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Type, cast
|
from typing import Type, cast
|
||||||
|
|
||||||
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
|
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model,hf_embedding
|
||||||
from .operate import (
|
from .operate import (
|
||||||
chunking_by_token_size,
|
chunking_by_token_size,
|
||||||
extract_entities,
|
extract_entities,
|
||||||
@@ -77,12 +77,13 @@ class LightRAG:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# text embedding
|
# text embedding
|
||||||
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
|
embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)#openai_embedding
|
||||||
embedding_batch_num: int = 32
|
embedding_batch_num: int = 32
|
||||||
embedding_func_max_async: int = 16
|
embedding_func_max_async: int = 16
|
||||||
|
|
||||||
# LLM
|
# LLM
|
||||||
llm_model_func: callable = gpt_4o_mini_complete
|
llm_model_func: callable = hf_model#gpt_4o_mini_complete
|
||||||
|
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||||
llm_model_max_token_size: int = 32768
|
llm_model_max_token_size: int = 32768
|
||||||
llm_model_max_async: int = 16
|
llm_model_max_async: int = 16
|
||||||
|
|
||||||
|
@@ -7,10 +7,12 @@ from tenacity import (
|
|||||||
wait_exponential,
|
wait_exponential,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
)
|
)
|
||||||
|
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
||||||
|
import torch
|
||||||
from .base import BaseKVStorage
|
from .base import BaseKVStorage
|
||||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||||
|
import copy
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
@@ -42,6 +44,52 @@ async def openai_complete_if_cache(
|
|||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
async def hf_model_if_cache(
|
||||||
|
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
model_name = model
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
|
||||||
|
if hf_tokenizer.pad_token == None:
|
||||||
|
# print("use eos token")
|
||||||
|
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||||
|
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
|
||||||
|
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.extend(history_messages)
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
if hashing_kv is not None:
|
||||||
|
args_hash = compute_args_hash(model, messages)
|
||||||
|
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||||
|
if if_cache_return is not None:
|
||||||
|
return if_cache_return["return"]
|
||||||
|
input_prompt = ''
|
||||||
|
try:
|
||||||
|
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
ori_message = copy.deepcopy(messages)
|
||||||
|
if messages[0]['role'] == "system":
|
||||||
|
messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
|
||||||
|
messages = messages[1:]
|
||||||
|
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
except:
|
||||||
|
len_message = len(ori_message)
|
||||||
|
for msgid in range(len_message):
|
||||||
|
input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
|
||||||
|
|
||||||
|
input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
|
||||||
|
output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
|
||||||
|
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
if hashing_kv is not None:
|
||||||
|
await hashing_kv.upsert(
|
||||||
|
{args_hash: {"return": response_text, "model": model}}
|
||||||
|
)
|
||||||
|
return response_text
|
||||||
|
|
||||||
|
|
||||||
async def gpt_4o_complete(
|
async def gpt_4o_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -65,6 +113,20 @@ async def gpt_4o_mini_complete(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def hf_model(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
input_string = kwargs.get('model_name', 'google/gemma-2-2b-it')
|
||||||
|
return await hf_model_if_cache(
|
||||||
|
input_string,
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
@@ -78,6 +140,24 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
|
|||||||
)
|
)
|
||||||
return np.array([dp.embedding for dp in response.data])
|
return np.array([dp.embedding for dp in response.data])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
global EMBED_MODEL
|
||||||
|
global tokenizer
|
||||||
|
EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
@wrap_embedding_func_with_attrs(
|
||||||
|
embedding_dim=384,
|
||||||
|
max_token_size=5000,
|
||||||
|
)
|
||||||
|
async def hf_embedding(texts: list[str]) -> np.ndarray:
|
||||||
|
input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = EMBED_MODEL(input_ids)
|
||||||
|
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||||
|
return embeddings.detach().numpy()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
@@ -3,7 +3,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
import warnings
|
||||||
from .utils import (
|
from .utils import (
|
||||||
logger,
|
logger,
|
||||||
clean_str,
|
clean_str,
|
||||||
@@ -398,10 +398,15 @@ async def local_query(
|
|||||||
keywords = keywords_data.get("low_level_keywords", [])
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ', '.join(keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
try:
|
||||||
|
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
|
||||||
|
keywords_data = json.loads(result)
|
||||||
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
|
keywords = ', '.join(keywords)
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
print(f"JSON parsing error: {e}")
|
except json.JSONDecodeError as e:
|
||||||
return PROMPTS["fail_response"]
|
print(f"JSON parsing error: {e}")
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
context = await _build_local_query_context(
|
context = await _build_local_query_context(
|
||||||
keywords,
|
keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
@@ -421,6 +426,9 @@ async def local_query(
|
|||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
|
if len(response)>len(sys_prompt):
|
||||||
|
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _build_local_query_context(
|
async def _build_local_query_context(
|
||||||
@@ -617,9 +625,16 @@ async def global_query(
|
|||||||
keywords = keywords_data.get("high_level_keywords", [])
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ', '.join(keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
# Handle parsing error
|
try:
|
||||||
print(f"JSON parsing error: {e}")
|
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
|
||||||
return PROMPTS["fail_response"]
|
keywords_data = json.loads(result)
|
||||||
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
|
keywords = ', '.join(keywords)
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
# Handle parsing error
|
||||||
|
print(f"JSON parsing error: {e}")
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
context = await _build_global_query_context(
|
context = await _build_global_query_context(
|
||||||
keywords,
|
keywords,
|
||||||
@@ -643,6 +658,9 @@ async def global_query(
|
|||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
|
if len(response)>len(sys_prompt):
|
||||||
|
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _build_global_query_context(
|
async def _build_global_query_context(
|
||||||
@@ -822,8 +840,8 @@ async def hybird_query(
|
|||||||
|
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
kw_prompt = kw_prompt_temp.format(query=query)
|
kw_prompt = kw_prompt_temp.format(query=query)
|
||||||
|
|
||||||
result = await use_model_func(kw_prompt)
|
result = await use_model_func(kw_prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
@@ -831,10 +849,18 @@ async def hybird_query(
|
|||||||
hl_keywords = ', '.join(hl_keywords)
|
hl_keywords = ', '.join(hl_keywords)
|
||||||
ll_keywords = ', '.join(ll_keywords)
|
ll_keywords = ', '.join(ll_keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
|
try:
|
||||||
|
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
|
||||||
|
keywords_data = json.loads(result)
|
||||||
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
|
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||||
|
hl_keywords = ', '.join(hl_keywords)
|
||||||
|
ll_keywords = ', '.join(ll_keywords)
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
print(f"JSON parsing error: {e}")
|
except json.JSONDecodeError as e:
|
||||||
return PROMPTS["fail_response"]
|
print(f"JSON parsing error: {e}")
|
||||||
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
low_level_context = await _build_local_query_context(
|
low_level_context = await _build_local_query_context(
|
||||||
ll_keywords,
|
ll_keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
@@ -851,7 +877,7 @@ async def hybird_query(
|
|||||||
text_chunks_db,
|
text_chunks_db,
|
||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
context = combine_contexts(high_level_context, low_level_context)
|
context = combine_contexts(high_level_context, low_level_context)
|
||||||
|
|
||||||
if query_param.only_need_context:
|
if query_param.only_need_context:
|
||||||
@@ -867,10 +893,13 @@ async def hybird_query(
|
|||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
|
if len(response)>len(sys_prompt):
|
||||||
|
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def combine_contexts(high_level_context, low_level_context):
|
def combine_contexts(high_level_context, low_level_context):
|
||||||
# Function to extract entities, relationships, and sources from context strings
|
# Function to extract entities, relationships, and sources from context strings
|
||||||
|
|
||||||
def extract_sections(context):
|
def extract_sections(context):
|
||||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
||||||
@@ -883,8 +912,21 @@ def combine_contexts(high_level_context, low_level_context):
|
|||||||
return entities, relationships, sources
|
return entities, relationships, sources
|
||||||
|
|
||||||
# Extract sections from both contexts
|
# Extract sections from both contexts
|
||||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
|
||||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
if high_level_context==None:
|
||||||
|
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
||||||
|
hl_entities, hl_relationships, hl_sources = '','',''
|
||||||
|
else:
|
||||||
|
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||||
|
|
||||||
|
|
||||||
|
if low_level_context==None:
|
||||||
|
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
|
||||||
|
ll_entities, ll_relationships, ll_sources = '','',''
|
||||||
|
else:
|
||||||
|
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Combine and deduplicate the entities
|
# Combine and deduplicate the entities
|
||||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
||||||
@@ -917,6 +959,7 @@ async def naive_query(
|
|||||||
global_config: dict,
|
global_config: dict,
|
||||||
):
|
):
|
||||||
use_model_func = global_config["llm_model_func"]
|
use_model_func = global_config["llm_model_func"]
|
||||||
|
use_model_name = global_config['llm_model_name']
|
||||||
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
results = await chunks_vdb.query(query, top_k=query_param.top_k)
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
@@ -939,6 +982,11 @@ async def naive_query(
|
|||||||
response = await use_model_func(
|
response = await use_model_func(
|
||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
|
model_name = use_model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(response)>len(sys_prompt):
|
||||||
|
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user