diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 0787403a9..6c23ac7bb 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -297,7 +297,7 @@ def forward(self, inputs, key_padding_mask=None): attn_mask=~key_padding_mask[:, None, None], dropout_p=self.dropout, ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) - out = out * attention_temperature + out = out * self.attention_temperature out = self.out_proj(out) return out