diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py index 57cfa105e4..fff4cab1ba 100644 --- a/optimum/bettertransformer/transformation.py +++ b/optimum/bettertransformer/transformation.py @@ -206,6 +206,12 @@ def transform( The converted model if the conversion has been successful. """ + hf_config = model.config + if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]: + raise ValueError( + f"Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type {hf_config.model_type}. Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention" + ) + # Check if we have to load the model using `accelerate` if hasattr(model, "hf_device_map"): load_accelerate = True @@ -236,13 +242,6 @@ def transform( f"BetterTransformer requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch." ) - hf_config = model.config - - if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]: - raise ValueError( - f"Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type {hf_config.model_type}. Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention" - ) - if load_accelerate: # Remove the hooks from the original model to avoid weights being on `meta` device. remove_hook_from_module(model, recurse=True) diff --git a/tests/bettertransformer/test_audio.py b/tests/bettertransformer/test_audio.py index 86f2b29693..fb1fccc57c 100644 --- a/tests/bettertransformer/test_audio.py +++ b/tests/bettertransformer/test_audio.py @@ -33,6 +33,16 @@ ] +class TestsWhisper(unittest.TestCase): + def test_error_message(self): + model = AutoModel.from_pretrained("openai/whisper-tiny") + + with self.assertRaises(ValueError) as cm: + model = BetterTransformer.transform(model) + + self.assertTrue("Transformers now supports natively BetterTransformer optimizations" in str(cm.exception)) + + class BetterTransformersBarkTest(BetterTransformersTestMixin, unittest.TestCase): r""" Testing suite for Bark - tests all the tests defined in `BetterTransformersTestMixin`