From f33bcbb32cfa7b4aad8936b564a288018188ecce Mon Sep 17 00:00:00 2001
From: Mario Vignieri
Date: Thu, 20 Mar 2025 09:40:56 +0100
Subject: [PATCH] fix hf_embed torch device: use MPS or CPU when CUDA is not
 available (macOS users)

---
 lightrag/llm/hf.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index fb5208b0..954a99b7 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -138,16 +138,31 @@ async def hf_model_complete(
 
 
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    device = next(embed_model.parameters()).device
+    # Detect the appropriate device
+    if torch.cuda.is_available():
+        device = next(embed_model.parameters()).device  # Use CUDA if available
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")  # Use MPS for Apple Silicon
+    else:
+        device = torch.device("cpu")  # Fallback to CPU
+
+    # Move the model to the detected device
+    embed_model = embed_model.to(device)
+
+    # Tokenize the input texts and move them to the same device
     encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
     ).to(device)
+
+    # Perform inference
     with torch.no_grad():
         outputs = embed_model(
             input_ids=encoded_texts["input_ids"],
             attention_mask=encoded_texts["attention_mask"],
         )
         embeddings = outputs.last_hidden_state.mean(dim=1)
+
+    # Convert embeddings to NumPy
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()
     else:
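
For reference, a minimal sketch of how the patched hf_embed could be exercised end to end. The model name is an arbitrary example (any Hugging Face encoder exposing last_hidden_state should do), not one mandated by LightRAG; on an Apple Silicon machine without CUDA, the patched code should route the model to MPS.

    import asyncio

    from transformers import AutoModel, AutoTokenizer

    from lightrag.llm.hf import hf_embed

    # Hypothetical example model, chosen only for illustration.
    MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    embed_model = AutoModel.from_pretrained(MODEL)

    # hf_embed is async, so drive it with asyncio.run; device selection
    # (CUDA -> MPS -> CPU) happens inside the patched function.
    embeddings = asyncio.run(
        hf_embed(["hello world", "ciao mondo"], tokenizer, embed_model)
    )
    print(embeddings.shape)  # (2, hidden_size); 384 for this MiniLM model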