Merge pull request #1139 from mvignieri/fix-conda-apple-silicon

fix hf_embed torch device use MPS or CPU when CUDA is not available -…
zrguo, 2025-03-21 13:24:00 +08:00 (committed by GitHub)

@@ -138,16 +138,31 @@ async def hf_model_complete(
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    device = next(embed_model.parameters()).device
+    # Detect the appropriate device
+    if torch.cuda.is_available():
+        device = next(embed_model.parameters()).device  # Use CUDA if available
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")  # Use MPS for Apple Silicon
+    else:
+        device = torch.device("cpu")  # Fallback to CPU
+
+    # Move the model to the detected device
+    embed_model = embed_model.to(device)
+
+    # Tokenize the input texts and move them to the same device
     encoded_texts = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
     ).to(device)
+
+    # Perform inference
     with torch.no_grad():
         outputs = embed_model(
             input_ids=encoded_texts["input_ids"],
             attention_mask=encoded_texts["attention_mask"],
         )
         embeddings = outputs.last_hidden_state.mean(dim=1)
+
+    # Convert embeddings to NumPy
     if embeddings.dtype == torch.bfloat16:
         return embeddings.detach().to(torch.float32).cpu().numpy()
     else:
         return embeddings.detach().cpu().numpy()
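
For reference, here is a minimal standalone sketch of the same CUDA -> MPS -> CPU fallback, runnable outside the patched function. The pick_device helper, the embed wrapper, and the sentence-transformers/all-MiniLM-L6-v2 model name are illustrative assumptions, not part of this commit:

    import numpy as np
    import torch
    from transformers import AutoModel, AutoTokenizer

    def pick_device() -> torch.device:
        # Same preference order as the patch: CUDA, then Apple Silicon MPS, then CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def embed(
        texts: list[str],
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",  # hypothetical choice
    ) -> np.ndarray:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(pick_device())
        device = next(model.parameters()).device
        # Inputs must live on the same device as the model.
        encoded = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = model(**encoded)
        embeddings = outputs.last_hidden_state.mean(dim=1)  # mean-pool token states
        if embeddings.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast to float32 before converting.
            return embeddings.detach().to(torch.float32).cpu().numpy()
        return embeddings.detach().cpu().numpy()

    print(embed(["hello world"]).shape)  # (1, 384) for this MiniLM checkpoint

The MPS branch is the substance of the fix: PyTorch builds on Apple Silicon report torch.cuda.is_available() as False, so before this change hf_embed always stayed on whatever device the model already occupied, typically the CPU, instead of using the Apple GPU.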