Add Hugging Face model support

This commit is contained in:
LarFii
2024-10-15 19:40:08 +08:00
parent af997c02c2
commit ea126a7108
11 changed files with 100 additions and 56 deletions

View File

@@ -3,7 +3,8 @@ import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast
from typing import Type, cast, Any
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
from .operate import (
@@ -11,7 +12,7 @@ from .operate import (
extract_entities,
local_query,
global_query,
hybird_query,
hybrid_query,
naive_query,
)
@@ -38,15 +39,14 @@ from .base import (
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return a usable event loop for the current thread, creating one if needed.

    The thread's current loop is reused when it exists and is still open.
    Otherwise a fresh loop is created, installed as the thread's current
    loop, and returned.

    Returns:
        An open ``asyncio.AbstractEventLoop`` bound to the calling thread.
    """
    try:
        # Use get_event_loop() rather than get_running_loop(): the latter
        # raises RuntimeError whenever no loop is *running*, so a synchronous
        # caller would fall through to the branch below on every call and
        # pile up new, never-closed loops.
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            # A closed loop cannot run anything; route it through the same
            # recovery path as "no loop available".
            raise RuntimeError("Event loop is closed.")
    except RuntimeError:
        # No current loop for this thread (e.g. a non-main thread), or the
        # current loop was already closed.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
@dataclass
class LightRAG:
working_dir: str = field(
@@ -77,6 +77,9 @@ class LightRAG:
)
# text embedding
tokenizer: Any = None
embed_model: Any = None
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)#
embedding_batch_num: int = 32
@@ -100,6 +103,13 @@ class LightRAG:
convert_response_to_json_func: callable = convert_response_to_json
def __post_init__(self):
if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding':
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
if self.embed_model is None:
self.embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
@@ -130,8 +140,11 @@ class LightRAG:
namespace="chunk_entity_relation", global_config=asdict(self)
)
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
lambda texts: self.embedding_func(texts, self.tokenizer, self.embed_model)
if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding'
else self.embedding_func(texts)
)
self.entities_vdb = (
self.vector_db_storage_cls(
namespace="entities",
@@ -267,8 +280,8 @@ class LightRAG:
param,
asdict(self),
)
elif param.mode == "hybird":
response = await hybird_query(
elif param.mode == "hybrid":
response = await hybrid_query(
query,
self.chunk_entity_relation_graph,
self.entities_vdb,