diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 48179f0e0a4cad..6d68405ab35a24 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -48,6 +48,7 @@ is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, is_torch_mlu_available, + is_torch_mps_available, is_torch_neuroncore_available, is_torch_npu_available, is_torch_tf32_available, @@ -2178,6 +2179,8 @@ def _setup_devices(self) -> "torch.device": ) if self.use_cpu: device = torch.device("cpu") + elif is_torch_mps_available(): + device = torch.device("mps") elif is_torch_xpu_available(): if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"): raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")