Add huggingface model support

LarFii
2024-10-15 19:40:08 +08:00
parent 2d0d25005b
commit 1e78068268
11 changed files with 100 additions and 56 deletions

View File: README.md

@@ -59,8 +59,8 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
 # Perform global search
 print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
-# Perform hybird search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybird")))
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 ```
 Batch Insert
 ```python
@@ -287,8 +287,8 @@ def extract_queries(file_path):
 ├── examples
 │   ├── batch_eval.py
 │   ├── generate_query.py
-│   ├── insert.py
-│   └── query.py
+│   ├── lightrag_openai_demo.py
+│   └── lightrag_hf_demo.py
 ├── lightrag
 │   ├── __init__.py
 │   ├── base.py
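The Batch Insert snippet is cut off in this view. A minimal sketch of the batched form, assuming `rag.insert` accepts a list of documents as well as a single string:

```python
# Batch insert (sketch): pass several documents at once instead of one string
rag.insert(["TEXT1", "TEXT2", "TEXT3"])
```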

View File: examples/insert.py (deleted)

@@ -1,18 +0,0 @@
-import os
-import sys
-
-from lightrag import LightRAG
-
-# os.environ["OPENAI_API_KEY"] = ""
-
-WORKING_DIR = ""
-
-if not os.path.exists(WORKING_DIR):
-    os.mkdir(WORKING_DIR)
-
-rag = LightRAG(working_dir=WORKING_DIR)
-
-with open('./text.txt', 'r') as f:
-    text = f.read()
-
-rag.insert(text)

View File: examples/lightrag_hf_demo.py (new file)

@@ -0,0 +1,36 @@
+import os
+import sys
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import hf_model_complete, hf_embedding
+from transformers import AutoModel, AutoTokenizer
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=hf_model_complete,
+    llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
+    embedding_func=hf_embedding,
+    tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
+    embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+)
+
+with open("./book.txt") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
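`hf_embedding` is registered with `embedding_dim=384` (see the lightrag/llm.py diff below), which matches all-MiniLM-L6-v2's hidden size, so swapping in a different Hugging Face encoder only works if its dimension agrees. A hypothetical swap, assuming another 384-dim sentence-transformers model:

```python
# Hypothetical: any encoder with a 384-dim hidden state keeps the decorator's
# embedding_dim=384 consistent (this model name is an example, not from the commit)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L6-v2")
```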

View File: examples/lightrag_openai_demo.py (new file)

@@ -0,0 +1,33 @@
+import os
+import sys
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
+from transformers import AutoModel, AutoTokenizer
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=gpt_4o_complete
+    # llm_model_func=gpt_4o_mini_complete
+)
+
+with open("./book.txt") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+# Perform local search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+# Perform global search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+# Perform hybrid search
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
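This demo relies on an OpenAI key being present in the environment, as the deleted example scripts hinted with their commented-out line; a minimal sketch:

```python
import os

# gpt_4o_complete reads the standard OpenAI environment variable (key value is a placeholder)
os.environ["OPENAI_API_KEY"] = "sk-..."
```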

View File: examples/query.py (deleted)

@@ -1,16 +0,0 @@
import os
import sys
from lightrag import LightRAG, QueryParam
# os.environ["OPENAI_API_KEY"] = ""
WORKING_DIR = ""
rag = LightRAG(working_dir=WORKING_DIR)
mode = 'global'
query_param = QueryParam(mode=mode)
result = rag.query("", param=query_param)
print(result)

View File: lightrag/__init__.py

@@ -1,5 +1,5 @@
 from .lightrag import LightRAG, QueryParam
 
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"

View File: lightrag/base.py

@@ -14,7 +14,7 @@ T = TypeVar("T")
 @dataclass
 class QueryParam:
-    mode: Literal["local", "global", "hybird", "naive"] = "global"
+    mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
     response_type: str = "Multiple Paragraphs"
     top_k: int = 60
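With the rename, query-time options are spelled the same everywhere; a short usage sketch built only from the fields shown above:

```python
from lightrag import QueryParam

# "hybrid" replaces the old misspelled "hybird" mode
param = QueryParam(mode="hybrid", response_type="Multiple Paragraphs", top_k=60)
```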

View File: lightrag/lightrag.py

@@ -3,7 +3,8 @@ import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Any
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
 
 from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
 from .operate import (
@@ -11,7 +12,7 @@ from .operate import (
     extract_entities,
     local_query,
     global_query,
-    hybird_query,
+    hybrid_query,
     naive_query,
 )
@@ -38,15 +39,14 @@ from .base import (
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
-        # If there is already an event loop, use it.
-        loop = asyncio.get_event_loop()
+        loop = asyncio.get_running_loop()
     except RuntimeError:
-        # If in a sub-thread, create a new event loop.
         logger.info("Creating a new event loop in a sub-thread.")
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
     return loop
 
 
 @dataclass
 class LightRAG:
     working_dir: str = field(
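The helper now calls `asyncio.get_running_loop()`, which raises `RuntimeError` unless a loop is actually running, instead of `asyncio.get_event_loop()`, whose implicit loop creation is deprecated in recent Python; the fallback branch then creates a loop explicitly. Reconstructed as standalone code under that reading of the diff:

```python
import asyncio

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # Reuse the loop we are already running inside, if any
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop (e.g. called from a plain thread): create and register one
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
```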
@@ -77,6 +77,9 @@ class LightRAG:
     )
     # text embedding
+    tokenizer: Any = None
+    embed_model: Any = None
+
     # embedding_func: EmbeddingFunc = field(default_factory=lambda: hf_embedding)
     embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
@@ -100,6 +103,13 @@ class LightRAG:
     convert_response_to_json_func: callable = convert_response_to_json
 
     def __post_init__(self):
+        if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding':
+            if self.tokenizer is None:
+                self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+            if self.embed_model is None:
+                self.embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
@@ -130,8 +140,11 @@ class LightRAG:
             namespace="chunk_entity_relation", global_config=asdict(self)
         )
         self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
-            self.embedding_func
+            lambda texts: self.embedding_func(texts, self.tokenizer, self.embed_model)
+            if callable(self.embedding_func) and self.embedding_func.__name__ == 'hf_embedding'
+            else self.embedding_func(texts)
         )
         self.entities_vdb = (
             self.vector_db_storage_cls(
                 namespace="entities",
@@ -267,8 +280,8 @@ class LightRAG:
                 param,
                 asdict(self),
             )
-        elif param.mode == "hybird":
-            response = await hybird_query(
+        elif param.mode == "hybrid":
+            response = await hybrid_query(
                 query,
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
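Note that the lambda passed to `limit_async_func_call` above refers to `self.embedding_func`, the very attribute being reassigned on that line, so by the time it runs it points at the wrapped callable rather than the original `hf_embedding`. A standalone sketch of the presumable intent, binding the original function before wrapping (the helper name is hypothetical):

```python
from functools import partial
from typing import Any, Callable

def bind_embedding_func(func: Callable, tokenizer: Any, embed_model: Any) -> Callable:
    """Hypothetical helper: capture the original embedding function up front,
    attaching the HF tokenizer/model only for hf_embedding, so the returned
    callable never re-enters the attribute it will later be stored in."""
    if callable(func) and getattr(func, "__name__", "") == "hf_embedding":
        # hf_embedding(texts, tokenizer, embed_model) per the llm.py diff below
        return partial(func, tokenizer=tokenizer, embed_model=embed_model)
    return func  # e.g. openai_embedding(texts)
```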

View File: lightrag/llm.py

@@ -142,18 +142,14 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
-global EMBED_MODEL
-global tokenizer
-EMBED_MODEL = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-
 @wrap_embedding_func_with_attrs(
     embedding_dim=384,
     max_token_size=5000,
 )
-async def hf_embedding(texts: list[str]) -> np.ndarray:
+async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
     with torch.no_grad():
-        outputs = EMBED_MODEL(input_ids)
+        outputs = embed_model(input_ids)
     embeddings = outputs.last_hidden_state.mean(dim=1)
     return embeddings.detach().numpy()
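With the module-level globals gone, the function can be exercised directly; a minimal sketch of a direct call using the models from the demo (run from a synchronous context, hence `asyncio.run`):

```python
import asyncio
from transformers import AutoModel, AutoTokenizer
from lightrag.llm import hf_embedding

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Mean-pools the last hidden state over all token positions, as in the diff
vecs = asyncio.run(hf_embedding(["hello world"], tokenizer, embed_model))
print(vecs.shape)  # expected (1, 384) for all-MiniLM-L6-v2
```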

View File: lightrag/operate.py

@@ -827,7 +827,7 @@ async def _find_related_text_unit_from_relationships(
     return all_text_units
 
 
-async def hybird_query(
+async def hybrid_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
     entities_vdb: BaseVectorStorage,

View File

@@ -52,7 +52,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 if __name__ == "__main__":
     cls = "agriculture"
-    mode = "hybird"
+    mode = "hybrid"
     WORKING_DIR = "../{cls}"
 
     rag = LightRAG(working_dir=WORKING_DIR)
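One quirk in the unchanged context lines: `WORKING_DIR = "../{cls}"` is a plain string, so `{cls}` is never interpolated. Presumably an f-string was intended; a one-line sketch of that assumption:

```python
# Assumed intent: interpolate the dataset class into the working-directory path
WORKING_DIR = f"../{cls}"
```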