LarFii
2024-10-15 21:21:57 +08:00
parent cc4af0d75f
commit e34922d292
4 changed files with 7 additions and 10 deletions


@@ -16,11 +16,13 @@ rag = LightRAG(
     llm_model_func=hf_model_complete,
     llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
     embedding_func=EmbeddingFunc(
-        tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
-        embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
         embedding_dim=384,
         max_token_size=5000,
-        func=hf_embedding
+        func=lambda texts: hf_embedding(
+            texts,
+            tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
+            embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+        )
     ),
 )
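This hunk changes the example so that EmbeddingFunc no longer receives tokenizer and embed_model directly; they are now passed to hf_embedding through a wrapping lambda. Below is a minimal sketch of the updated initialization, assuming EmbeddingFunc is importable from lightrag.utils and hf_model_complete / hf_embedding from lightrag.llm (as in the repository's examples); hoisting the from_pretrained calls out of the lambda is an optional variation so the tokenizer and model are loaded once rather than on every embedding call.

import os

from transformers import AutoModel, AutoTokenizer

from lightrag import LightRAG
from lightrag.llm import hf_model_complete, hf_embedding
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# Load the Hugging Face tokenizer and embedding model once, then close over
# them so hf_embedding receives them explicitly on every call.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=hf_model_complete,
    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
    embedding_func=EmbeddingFunc(
        embedding_dim=384,  # all-MiniLM-L6-v2 produces 384-dimensional vectors
        max_token_size=5000,
        func=lambda texts: hf_embedding(texts, tokenizer=tokenizer, embed_model=embed_model),
    ),
)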


@@ -5,7 +5,7 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
 from transformers import AutoModel,AutoTokenizer
-WORKING_DIR = "/home/zrguo/code/myrag/agriculture"
+WORKING_DIR = "./dickens"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)


@@ -1,5 +1,5 @@
 from .lightrag import LightRAG, QueryParam
-__version__ = "0.0.4"
+__version__ = "0.0.5"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"


@@ -141,11 +141,6 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
     return np.array([dp.embedding for dp in response.data])
-@wrap_embedding_func_with_attrs(
-    embedding_dim=384,
-    max_token_size=5000,
-)
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
     with torch.no_grad():
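Removing @wrap_embedding_func_with_attrs means hf_embedding no longer carries fixed embedding_dim / max_token_size attributes; those values are now supplied by the caller's EmbeddingFunc (see the first hunk). A small sketch of calling the function directly under that assumption, keeping the async signature shown above; the printed shape is only the expected result for all-MiniLM-L6-v2.

import asyncio

from transformers import AutoModel, AutoTokenizer
from lightrag.llm import hf_embedding

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# hf_embedding is an async coroutine, so drive a one-off call with asyncio.run.
vectors = asyncio.run(hf_embedding(["hello world"], tokenizer=tokenizer, embed_model=embed_model))
print(vectors.shape)  # expected (1, 384) for all-MiniLM-L6-v2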