diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 54da37f0e..46c524381 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -40,7 +40,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240}) - logging.info(tabuleate_fn(*fake_input_batch), train=False) + logging.info(tabulate_fn(*fake_input_batch), train=False) model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, dropout_rng = jax.random.split(rng, 2)