diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 1d2c48dbc..90ae7205a 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -40,7 +40,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240}) - logging.info(tabulate_fn(*fake_input_batch, train=False)) + # logging.info(tabulate_fn(*fake_input_batch, train=False)) model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, dropout_rng = jax.random.split(rng, 2)