diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..327ca34ad 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,6 +5,7 @@ from absl import logging import jax +import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -162,7 +163,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index 0aa4dce49..23e86a13b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,6 @@ wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 sacrebleu==2.4.3 - # Frameworks # # JAX Core