Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 7, 2023
1 parent 8af0c73 commit 59dfe11
Showing 1 changed file with 9 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, Optional, Tuple

from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -35,6 +36,14 @@ def init_model_fn(
dropout_rate=dropout_rate)
variables = jax.jit(self._model.init)({'params': rng}, fake_batch)
params = variables['params']
tabulate_fn = nn.tabulate(
self._model,
jax.random.PRNGKey(0),
console_kwargs={
'force_terminal': False, 'force_jupyter':False, 'width':240
}
)
print(tabulate_fn(fake_batch, train=False))
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
params = jax_utils.replicate(params)
Expand Down

0 comments on commit 59dfe11

Please sign in to comment.