Skip to content

Commit

Permalink
fix: explicitly using mask kwarg to use MultiHeadDotProductAttention and also using sacrebleu
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Dec 2, 2024
1 parent abbdc82 commit 86029a7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import math

import tensorflow as tf

#from tensorflow_addons import image as contrib_image

# This signifies the max integer that the controller RNN could predict for the
Expand Down
6 changes: 3 additions & 3 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __call__(self, inputs, encoder_mask=None):
dropout_rate=attention_dropout_rate,
deterministic=cfg.deterministic)(cfg.attention_temp * x,
x,
- encoder_mask)
+ mask=encoder_mask)

if cfg.dropout_rate is None:
dropout_rate = 0.1
Expand Down Expand Up @@ -288,7 +288,7 @@ def __call__(self,
broadcast_dropout=False,
dropout_rate=attention_dropout_rate,
deterministic=cfg.deterministic,
- decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask)
+ decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask)
if cfg.dropout_rate is None:
dropout_rate = 0.1
else:
Expand All @@ -311,7 +311,7 @@ def __call__(self,
dropout_rate=attention_dropout_rate,
deterministic=cfg.deterministic)(cfg.attention_temp * y,
encoded,
- encoder_decoder_mask)
+ mask=encoder_decoder_mask)

y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
y = y + x
Expand Down
3 changes: 2 additions & 1 deletion algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import jax.numpy as jnp
import numpy as np
import optax
+ import sacrebleu

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
Expand Down Expand Up @@ -203,7 +204,7 @@ def translate_and_calculate_bleu(self,
predictions.append(self._decode_tokens(predicted[idx]))

# Calculate BLEU score for translated eval corpus against reference.
- bleu_score = bleu.corpus_bleu(predictions, [references]).score
+ bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score
return bleu_score

def init_model_fn(
Expand Down

0 comments on commit 86029a7

Please sign in to comment.