Skip to content

Commit

Permalink
add tabulate for deepspeech debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Sep 26, 2023
1 parent 7970388 commit 862f500
Showing 1 changed file with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from absl import logging

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
Expand Down Expand Up @@ -34,6 +36,11 @@ def init_model_fn(
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]

tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0),
console_kwargs={'force_terminal': False,
'force_jupyter': False,
'width': 240})
logging.info(tabuleate_fn(*fake_input_batch), train=False)
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))

params_rng, dropout_rng = jax.random.split(rng, 2)
Expand Down

0 comments on commit 862f500

Please sign in to comment.