
Commit c84cab4
fix sdpa musicgen
ylacombe committed Jun 3, 2024
1 parent d475f76 · commit c84cab4
Showing 1 changed file with 14 additions and 0 deletions.
src/transformers/models/musicgen/modeling_musicgen.py: 14 additions & 0 deletions
@@ -571,6 +571,20 @@ def forward(
                 layer_head_mask=layer_head_mask,
                 output_attentions=output_attentions,
             )
+        # Ignore copy
+        if attention_mask is not None and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any():
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model. '
+                'Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )

         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
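For context, a minimal, standalone sketch (not part of the commit; it assumes the transformers convention that additive attention masks hold 0 for visible positions and the dtype minimum for masked ones) showing why the `mean(dim=[1, 2, 3]) <= torch.finfo(dtype).min` test flags a fully masked batch item, which is the case SDPA cannot handle without producing NaNs:

import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min

# Shape (batch, heads, tgt_len, src_len): batch item 0 is fully masked
# (e.g. the unconditional half of a classifier-free-guidance batch),
# batch item 1 attends everywhere.
attention_mask = torch.zeros(2, 1, 4, 4, dtype=dtype)
attention_mask[0].fill_(min_val)

# A batch item whose mask averages down to the dtype minimum (or below,
# once float accumulation overflows toward -inf) has no attendable
# position at all.
fully_masked = (attention_mask.mean(dim=[1, 2, 3]) <= min_val).any()
print(fully_masked)  # tensor(True) -> fall back to the eager implementation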

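The warning points at the two usual triggers. A hedged usage sketch (model id, prompt, and generation settings are illustrative, not taken from the commit): with `guidance_scale > 1`, MusicGen runs classifier-free guidance, which doubles the batch with an unconditional half whose text cross-attention mask is entirely masked, so the cross-attention layers take the new fallback path.

from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small", attn_implementation="sdpa"
)

inputs = processor(text=["80s pop track with synth"], return_tensors="pt")

# guidance_scale > 1 enables classifier-free guidance; the unconditional
# half of the batch is what used to break SDPA before this fix.
audio = model.generate(**inputs, guidance_scale=3.0, max_new_tokens=64)

# Unconditional generation via get_unconditional_inputs hits the same path:
# unconditional = model.get_unconditional_inputs(num_samples=1)
# audio = model.generate(**unconditional, max_new_tokens=64)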