From 5eac985fcefc7fa0f93c2e4f28e0d71ca6db7d3d Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 7 Dec 2024 21:07:21 +0530 Subject: [PATCH] fix: going back to sacrebleu v1.3.1 --- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 5 ++--- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 5 ++--- setup.cfg | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 72108c9d9..046d5e469 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,11 +12,10 @@ import jax.numpy as jnp import numpy as np import optax -import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -204,7 +203,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index b554b2ab3..0ba49c2f6 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,7 +5,6 @@ from absl import logging import jax -import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -16,7 +15,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -163,7 +162,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index e8044fe02..a7c224407 100644 --- a/setup.cfg +++ b/setup.cfg @@ -103,7 +103,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==2.4.3 + sacrebleu==1.3.1 # Frameworks # # JAX Core