diff --git a/projects/MT5/layers/attention_layer.py b/projects/MT5/layers/attention_layer.py
index 166e792bf..3cfe8c41d 100644
--- a/projects/MT5/layers/attention_layer.py
+++ b/projects/MT5/layers/attention_layer.py
@@ -218,7 +218,7 @@ def forward(
                     )
                 else:
                     position_bias = self.compute_bias(
-                        real_seq_length, key_length, placement=attention_mask.placement
+                        real_seq_length, key_length, placement=attention_scores.placement
                     )
 
                 if past_key_value is not None:
@@ -228,13 +228,14 @@ def forward(
             if use_cache:
                 attention_mask = attention_mask.expand_as(attention_scores)
 
+            attention_dropout_prob = self.attention_dropout_prob if self.training else 0.0
             attention_weights = flow._C.fused_bias_add_scale_mask_softmax_dropout(
                 attention_scores,
                 position_bias,
                 attention_mask,
                 fill_value=-10000.0,
                 scale=1,
-                p=self.attention_dropout_prob,
+                p=attention_dropout_prob,
             )[0]
         else:
             attention_scores = attention_scores + position_bias
diff --git a/projects/MT5/mt5_model.py b/projects/MT5/mt5_model.py
index a16144ff2..2f50b735f 100644
--- a/projects/MT5/mt5_model.py
+++ b/projects/MT5/mt5_model.py
@@ -203,7 +203,7 @@ def forward(
         position_bias = None
         encoder_decoder_position_bias = None
         self.set_cache(encoder_states=None, past_key_values=None)
-        encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
+        encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask) if encoder_attn_mask is not None else encoder_attn_mask
         enc_embedding_output = self.embedding(encoder_input_ids)
         # transpose [batch_size, seq_len, embed_size] to [seq_len, batch_size, embed_size]
         enc_hidden_states = enc_embedding_output.transpose(0, 1)
@@ -219,10 +219,11 @@ def forward(
         if only_encoder:
             return encoder_states
 
-        decoder_attn_mask = self.extended_attn_mask(
-            decoder_attn_mask, decoder_input_ids, is_decoder=True
-        )
-        encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)
+        if decoder_attn_mask is not None:
+            decoder_attn_mask = self.extended_attn_mask(
+                decoder_attn_mask, decoder_input_ids, is_decoder=True
+            )
+        encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask) if encoder_decoder_attn_mask is not None else encoder_decoder_attn_mask
         dec_embedding_output = self.embedding(decoder_input_ids)
         # transpose [batch_size, seq_len, embed_size] to [seq_len, batch_size, embed_size]
         dec_hidden_states = dec_embedding_output.transpose(0, 1)
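
Note (not part of the patch): a minimal standalone sketch of the two behavior changes above, assuming OneFlow's PyTorch-style flow.nn API; the TinyAttention module, its argument names, and the masked_fill-based masking are hypothetical and exist only for illustration, not the MT5 fused-kernel path itself.

    import oneflow as flow

    class TinyAttention(flow.nn.Module):
        """Hypothetical module illustrating the patch's two behavior changes."""

        def __init__(self, attention_dropout_prob=0.1):
            super().__init__()
            self.attention_dropout_prob = attention_dropout_prob

        def forward(self, attention_scores, attention_mask=None):
            # mt5_model.py change (by analogy): only apply a mask that was actually passed in.
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask == 0, -10000.0)
            weights = flow.nn.functional.softmax(attention_scores, dim=-1)
            # attention_layer.py change: zero the dropout probability outside training
            # so eval/inference stays deterministic.
            p = self.attention_dropout_prob if self.training else 0.0
            return flow.nn.functional.dropout(weights, p=p)

    attn = TinyAttention().eval()        # .training is False -> p == 0.0
    scores = flow.randn(2, 8, 16, 16)    # [batch, heads, seq, seq]
    out = attn(scores)                   # no mask given, no dropout at eval time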