diff --git a/modules/faster_whisper_inference.py b/modules/faster_whisper_inference.py index 88b4961b..c20c3282 100644 --- a/modules/faster_whisper_inference.py +++ b/modules/faster_whisper_inference.py @@ -31,7 +31,12 @@ def __init__(self): self.available_models = self.model_paths.keys() self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values())) self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"] - self.device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.backends.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" self.available_compute_types = ctranslate2.get_supported_compute_types( "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu") self.current_compute_type = "float16" if self.device == "cuda" else "float32" diff --git a/modules/whisper_Inference.py b/modules/whisper_Inference.py index 65782f9c..736fa30b 100644 --- a/modules/whisper_Inference.py +++ b/modules/whisper_Inference.py @@ -23,6 +23,12 @@ def __init__(self): self.available_models = whisper.available_models() self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values())) self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"] + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.backends.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu" self.available_compute_types = ["float16", "float32"] self.current_compute_type = "float16" if self.device == "cuda" else "float32"