From 86029a742094a653e5bf9a6f17f0d42c0990671d Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:10:24 +0530 Subject: [PATCH] fix: explicitly using mask kwarg to use MultiHeadDotProductAttention and also using sacrebleu --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 1 + algorithmic_efficiency/workloads/wmt/wmt_jax/models.py | 6 +++--- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index d0bbecb8f..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + #from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index e4b5cd014..7bbc0b168 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -224,7 +224,7 @@ def __call__(self, inputs, encoder_mask=None): dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * x, x, - encoder_mask) + mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +288,7 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask) + decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -311,7 +311,7 @@ def __call__(self, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * y, encoded, - encoder_decoder_mask) + mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..442c85899 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,6 +12,7 @@ import jax.numpy as jnp import numpy as np import optax +import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -203,7 +204,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn(