Add huggingface model support

This commit is contained in:
LarFii
2024-10-15 19:40:08 +08:00
parent 2d0d25005b
commit 1e78068268
11 changed files with 100 additions and 56 deletions

View File

@@ -59,8 +59,8 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
 # Perform global search
 print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
-# Perform hybird search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybird")))
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 ```
 Batch Insert
 ```python
@@ -287,8 +287,8 @@ def extract_queries(file_path):
 ├── examples
 │   ├── batch_eval.py
 │   ├── generate_query.py
-│   ├── insert.py
-│   └── query.py
+│   ├── lightrag_openai_demo.py
+│   └── lightrag_hf_demo.py
 ├── lightrag
 │   ├── __init__.py
 │   ├── base.py

View File

@@ -1,18 +0,0 @@
-import os
-import sys
-from lightrag import LightRAG
-# os.environ["OPENAI_API_KEY"] = ""
-WORKING_DIR = ""
-if not os.path.exists(WORKING_DIR):
-    os.mkdir(WORKING_DIR)
-rag = LightRAG(working_dir=WORKING_DIR)
-with open('./text.txt', 'r') as f:
-    text = f.read()
-rag.insert(text)

View File

@@ -0,0 +1,36 @@
+import os
+import sys
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import hf_model_complete, hf_embedding
+from transformers import AutoModel, AutoTokenizer
+WORKING_DIR = "./dickens"
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=hf_model_complete,
+    llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
+    embedding_func=hf_embedding,
+    tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
+    embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+)
+with open("./book.txt") as f:
+    rag.insert(f.read())
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
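Any encoder can stand in for all-MiniLM-L6-v2 here, as long as its vectors match the `embedding_dim=384` that `hf_embedding` is registered with (see the llm.py diff further down). A hedged variant of the constructor, using another 384-dim sentence-transformers checkpoint:

```python
# Sketch only: paraphrase-MiniLM-L6-v2 is assumed here as an alternative
# 384-dim encoder; it is not part of this commit.
from transformers import AutoModel, AutoTokenizer
from lightrag import LightRAG
from lightrag.llm import hf_model_complete, hf_embedding

ENCODER = "sentence-transformers/paraphrase-MiniLM-L6-v2"  # also 384-dim

rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=hf_model_complete,
    llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
    embedding_func=hf_embedding,
    tokenizer=AutoTokenizer.from_pretrained(ENCODER),
    embed_model=AutoModel.from_pretrained(ENCODER),
)
```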

View File

@@ -0,0 +1,33 @@
+import os
+import sys
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
+from transformers import AutoModel, AutoTokenizer
+WORKING_DIR = "./dickens"
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=gpt_4o_complete
+    # llm_model_func=gpt_4o_mini_complete
+)
+with open("./book.txt") as f:
+    rag.insert(f.read())
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
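Unlike the deleted scripts, this demo no longer carries the commented-out `OPENAI_API_KEY` line, so the key has to come from the environment before the first call to `gpt_4o_complete`. A minimal sketch:

```python
import os

# The openai client reads OPENAI_API_KEY from the environment;
# "sk-..." is a placeholder, not a real credential.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")
```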

View File

@@ -1,16 +0,0 @@
-import os
-import sys
-from lightrag import LightRAG, QueryParam
-# os.environ["OPENAI_API_KEY"] = ""
-WORKING_DIR = ""
-rag = LightRAG(working_dir=WORKING_DIR)
-mode = 'global'
-query_param = QueryParam(mode=mode)
-result = rag.query("", param=query_param)
-print(result)

View File

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG, QueryParam
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -14,7 +14,7 @@ T = TypeVar("T")
 @dataclass
 class QueryParam:
-    mode: Literal["local", "global", "hybird", "naive"] = "global"
+    mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
     response_type: str = "Multiple Paragraphs"
     top_k: int = 60
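Besides the runtime rename, fixing the `Literal` means a static type checker now catches the misspelling. A quick usage sketch with the other fields shown above:

```python
from lightrag import QueryParam

# "hybrid" is now an accepted Literal; the old "hybird" spelling
# would be rejected by a static type checker.
param = QueryParam(mode="hybrid", only_need_context=True, top_k=30)
print(param)
```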

View File

@@ -3,7 +3,8 @@ import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Any
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
 from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
 from .operate import (
@@ -11,7 +12,7 @@ from .operate import (
     extract_entities,
     local_query,
     global_query,
-    hybird_query,
+    hybrid_query,
     naive_query,
 )
@@ -38,15 +39,14 @@ from .base import (
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         # If there is already an event loop, use it.
-        loop = asyncio.get_event_loop()
+        loop = asyncio.get_running_loop()
     except RuntimeError:
         # If in a sub-thread, create a new event loop.
         logger.info("Creating a new event loop in a sub-thread.")
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
     return loop
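The switch to `get_running_loop` matters because it raises `RuntimeError` when no loop is running instead of implicitly creating one (newer Python versions deprecate that implicit creation), and the `except` branch handles exactly that case. A standalone sketch of the pattern (`get_or_create_event_loop` is a hypothetical name for illustration):

```python
import asyncio

def get_or_create_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # Succeeds only inside a running loop; unlike get_event_loop(),
        # it never creates a loop implicitly.
        return asyncio.get_running_loop()
    except RuntimeError:
        # No running loop (plain sync code or a sub-thread): make one
        # and register it for this thread.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop

loop = get_or_create_event_loop()
print(loop.run_until_complete(asyncio.sleep(0, result="ok")))  # ok
```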
 @dataclass
 class LightRAG:
     working_dir: str = field(
@@ -77,6 +77,9 @@
     )
     # text embedding
+    tokenizer: Any = None
+    embed_model: Any = None
+    # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
     embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
@@ -100,6 +103,13 @@
     convert_response_to_json_func: callable = convert_response_to_json
     def __post_init__(self):
+        if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding':
+            if self.tokenizer is None:
+                self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+            if self.embed_model is None:
+                self.embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
@@ -130,8 +140,11 @@
             namespace="chunk_entity_relation", global_config=asdict(self)
         )
+        # Capture the original function first so the lambda below closes over
+        # it, not over the re-assigned (wrapped) attribute, which would recurse.
+        embedding_func = self.embedding_func
         self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
-            self.embedding_func
+            (lambda texts: embedding_func(texts, self.tokenizer, self.embed_model))
+            if callable(embedding_func) and embedding_func.__name__ == 'hf_embedding'
+            else embedding_func
         )
         self.entities_vdb = (
             self.vector_db_storage_cls(
                 namespace="entities",
@@ -267,8 +280,8 @@
                 param,
                 asdict(self),
             )
-        elif param.mode == "hybird":
-            response = await hybird_query(
+        elif param.mode == "hybrid":
+            response = await hybrid_query(
                 query,
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
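The conditional wrapping above adapts the three-argument `hf_embedding` to the single-argument interface the storage layers call. A minimal sketch of the same binding done with `functools.partial`, using a fake stand-in function (`fake_hf_embedding` is hypothetical, not part of LightRAG):

```python
import asyncio
from functools import partial

# A stand-in with hf_embedding's new shape: (texts, tokenizer, embed_model).
async def fake_hf_embedding(texts, tokenizer, embed_model):
    # Return one dummy vector per input text.
    return [[float(len(t))] for t in texts]

# Bind the model objects once; downstream code then passes only `texts`,
# matching the single-argument signature the vector stores expect.
bound = partial(fake_hf_embedding, tokenizer="tok", embed_model="model")

print(asyncio.run(bound(["hello", "world"])))  # [[5.0], [5.0]]
```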

View File

@@ -142,18 +142,14 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
-global EMBED_MODEL
-global tokenizer
-EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
 @wrap_embedding_func_with_attrs(
     embedding_dim=384,
     max_token_size=5000,
 )
-async def hf_embedding(texts: list[str]) -> np.ndarray:
+async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
     with torch.no_grad():
-        outputs = EMBED_MODEL(input_ids)
+        outputs = embed_model(input_ids)
     embeddings = outputs.last_hidden_state.mean(dim=1)
     return embeddings.detach().numpy()
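With the module-level globals gone, everything `hf_embedding` needs now arrives as arguments. A self-contained sketch that mirrors its body, tokenize → forward pass → mean-pool token states (downloads the MiniLM weights on first run):

```python
import asyncio
import torch
from transformers import AutoModel, AutoTokenizer

async def embed(texts):
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    # Same steps as hf_embedding: tokenize, run the encoder, average the
    # per-token hidden states into one vector per text.
    input_ids = tokenizer(texts, return_tensors='pt',
                          padding=True, truncation=True).input_ids
    with torch.no_grad():
        outputs = embed_model(input_ids)
    return outputs.last_hidden_state.mean(dim=1).numpy()

vectors = asyncio.run(embed(["a tale of two cities"]))
print(vectors.shape)  # (1, 384) for all-MiniLM-L6-v2
```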

View File

@@ -827,7 +827,7 @@ async def _find_related_text_unit_from_relationships(
     return all_text_units

-async def hybird_query(
+async def hybrid_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,

View File

@@ -52,7 +52,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 if __name__ == "__main__":
     cls = "agriculture"
-    mode = "hybird"
+    mode = "hybrid"
     WORKING_DIR = "../{cls}"
     rag = LightRAG(working_dir=WORKING_DIR)
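One pre-existing wrinkle in the surrounding context: `WORKING_DIR = "../{cls}"` is not an f-string, so `{cls}` is taken literally rather than interpolated. The presumably intended line:

```python
WORKING_DIR = f"../{cls}"  # f-prefix added so {cls} actually interpolates
```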