From 862f5009b571d64ec3a808e05689be7b079cbbc3 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 26 Sep 2023 22:55:51 +0000
Subject: [PATCH] add tabulate for deepspeech debugging

---
 .../librispeech_deepspeech/librispeech_jax/workload.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 4086a5841..54da37f0e 100644
--- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -5,6 +5,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import flax.linen as nn
+from absl import logging
 
 from algorithmic_efficiency import param_utils
 from algorithmic_efficiency import spec
@@ -34,6 +36,11 @@ def init_model_fn(
     input_shape = [(320000,), (320000,)]
     fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
 
+    tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0),
+                              console_kwargs={'force_terminal': False,
+                                              'force_jupyter': False,
+                                              'width': 240})
+    logging.info(tabulate_fn(*fake_input_batch, train=False))
     model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
 
     params_rng, dropout_rng = jax.random.split(rng, 2)
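
Note: for reference, below is a minimal standalone sketch of the flax.linen.tabulate debugging pattern this patch adds. The ToyEncoder module, its input shapes, and feature sizes are made up for illustration and are not part of the DeepSpeech workload; only the nn.tabulate / console_kwargs / logging.info usage mirrors the patch.

# Standalone sketch of the nn.tabulate debugging pattern (hypothetical toy model).
import flax.linen as nn
import jax
import jax.numpy as jnp
from absl import logging


class ToyEncoder(nn.Module):
  """Small stand-in module so the example runs on its own."""

  @nn.compact
  def __call__(self, inputs, train=False):
    x = nn.Dense(features=32)(inputs)
    x = nn.relu(x)
    return nn.Dense(features=8)(x)


model = ToyEncoder()
fake_inputs = jnp.zeros((2, 16), jnp.float32)

# nn.tabulate returns a function; calling it with the module's __call__
# arguments renders a table of layers, parameter shapes, and sizes, which
# is then emitted through absl logging as in the patch above.
tabulate_fn = nn.tabulate(model, jax.random.PRNGKey(0),
                          console_kwargs={'force_terminal': False,
                                          'force_jupyter': False,
                                          'width': 240})
logging.info(tabulate_fn(fake_inputs, train=False))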