diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index fb5208b0..954a99b7 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -138,16 +138,31 @@ async def hf_model_complete(
 
 
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    device = next(embed_model.parameters()).device
+    # Detect the appropriate device
+    if torch.cuda.is_available():
+        device = next(embed_model.parameters()).device  # Keep the model's current device when CUDA is available
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")  # Use MPS on Apple Silicon
+    else:
+        device = torch.device("cpu")  # Fall back to CPU
+
+    # Move the model to the detected device
+    embed_model = embed_model.to(device)
+
+    # Tokenize the input texts and move them to the same device
     encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
     ).to(device)
+
+    # Run inference without gradient tracking
     with torch.no_grad():
         outputs = embed_model(
             input_ids=encoded_texts["input_ids"],
             attention_mask=encoded_texts["attention_mask"],
         )
         embeddings = outputs.last_hidden_state.mean(dim=1)
+
+    # Convert embeddings to NumPy (NumPy has no bfloat16 dtype)
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()
     else:
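
For context, here is a minimal usage sketch of the updated `hf_embed`. The checkpoint name and the `asyncio` scaffolding are illustrative assumptions, not part of this change; any encoder-style Hugging Face model that exposes `last_hidden_state` should behave the same way. Note that the model can be loaded on CPU as usual, since `hf_embed` now relocates it to the detected device before inference:

```python
import asyncio

import numpy as np
from transformers import AutoModel, AutoTokenizer

from lightrag.llm.hf import hf_embed

# Hypothetical checkpoint, chosen only for illustration.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
embed_model = AutoModel.from_pretrained(MODEL_NAME)


async def main() -> None:
    texts = [
        "LightRAG indexes documents into a knowledge graph.",
        "Embeddings are computed on the detected device.",
    ]
    # hf_embed handles device placement (CUDA / MPS / CPU) internally
    # and always returns a float32-compatible NumPy array.
    embeddings: np.ndarray = await hf_embed(texts, tokenizer, embed_model)
    print(embeddings.shape)  # (2, hidden_size), e.g. (2, 384) for MiniLM


asyncio.run(main())
```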