Skip to content

Commit

Permalink
Check output_attentions is False in BetterTransformer (#1306)
Browse files Browse the repository at this point in the history
add checks on output_attentions
  • Loading branch information
fxmarty authored Aug 22, 2023
1 parent 0323172 commit 2c1eaf6
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __init__(self, bert_layer, config):
self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_):
# No check on output_attentions here as roformer relies on BertLayerBetterTransformer but does not pass output_attentions as keyword argument.
if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if hidden_states.is_nested:
attention_mask = None
Expand Down Expand Up @@ -463,7 +464,10 @@ def __init__(self, bart_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
Expand Down Expand Up @@ -655,7 +659,10 @@ def __init__(self, mbart_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
Expand Down Expand Up @@ -842,7 +849,10 @@ def __init__(self, bert_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=None, *_):
def forward(self, hidden_states, attn_mask, output_attentions: bool, head_mask=None, *_):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if hidden_states.is_nested:
attn_mask = None
Expand Down Expand Up @@ -1019,7 +1029,10 @@ def __init__(self, whisper_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_, **__):
def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
attention_mask = None # attention mask seems to be always None: https://github.com/huggingface/transformers/blob/94b3f544a1f5e04b78d87a2ae32a7ac252e22e31/src/transformers/models/whisper/modeling_whisper.py#L690

Expand Down Expand Up @@ -1139,7 +1152,10 @@ def __init__(self, vit_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, *_, **__):
def forward(self, hidden_states, output_attentions: bool, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
attention_mask = None

Expand Down Expand Up @@ -1259,7 +1275,10 @@ def __init__(self, vilt_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, *_, **__):
def forward(self, hidden_states, layer_head_mask, output_attentions: bool, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
attention_mask = None

Expand Down Expand Up @@ -1375,7 +1394,10 @@ def __init__(self, wav2vec2_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, **__):
def forward(self, hidden_states, attention_mask, output_attentions: bool, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if hidden_states.is_nested:
attention_mask = None
Expand Down Expand Up @@ -1497,7 +1519,10 @@ def __init__(self, fsmt_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
def forward(self, hidden_states, attention_mask, output_attentions: bool, position_bias=None, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
Expand Down Expand Up @@ -1638,7 +1663,10 @@ def __init__(self, prophetnet_layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_, **__):
def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
Expand Down Expand Up @@ -1772,10 +1800,13 @@ def __init__(self, layer, config):

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_, **__):
def forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions: bool, *_, **__):
if output_attentions:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if not self.training and not torch.is_autocast_enabled() and not torch.is_autocast_cpu_enabled():
# we expect attention_mask to be None in the vision model
if attention_mask is not None:
if attention_mask is not None or causal_attention_mask is not None:
raise ValueError(
"Please do not use attention masks when using `BetterTransformer` converted vision models"
)
Expand Down

0 comments on commit 2c1eaf6

Please sign in to comment.