From 59dfe11e569a3d54050144baf6ae327f6ccddf1e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 7 Dec 2023 04:02:48 +0000 Subject: [PATCH] debugging --- .../workloads/fastmri/fastmri_jax/workload.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index a8ad7db94..155b2dff1 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -5,6 +5,7 @@ from typing import Dict, Optional, Tuple from flax import jax_utils +import flax.linen as nn import jax import jax.numpy as jnp @@ -35,6 +36,14 @@ def init_model_fn( dropout_rate=dropout_rate) variables = jax.jit(self._model.init)({'params': rng}, fake_batch) params = variables['params'] + tabulate_fn = nn.tabulate( + self._model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, 'force_jupyter':False, 'width':240 + } + ) + print(tabulate_fn(fake_batch, train=False)) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) params = jax_utils.replicate(params)