diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index 155b2dff1..f34c854e9 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -36,14 +36,6 @@ def init_model_fn( dropout_rate=dropout_rate) variables = jax.jit(self._model.init)({'params': rng}, fake_batch) params = variables['params'] - tabulate_fn = nn.tabulate( - self._model, - jax.random.PRNGKey(0), - console_kwargs={ - 'force_terminal': False, 'force_jupyter':False, 'width':240 - } - ) - print(tabulate_fn(fake_batch, train=False)) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) params = jax_utils.replicate(params) diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 30939b7d4..f8e073d8b 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -28,9 +28,9 @@ def sort_key(k): keys = sorted(sd.keys(), key=sort_key) c = 0 - jax_weights_name = 'kernel' for idx, k in enumerate(keys): new_key = [] + jax_weights_name = 'kernel' for idx2, i in enumerate(k): if 'ModuleList' in i or 'Sequential' in i: continue @@ -44,6 +44,7 @@ def sort_key(k): i = i.replace('ConvTranspose2d', 'ConvTranspose') if 'LayerNorm' in i: jax_weights_name = 'scale' + continue if 'weight' in i: i = i.replace('weight', jax_weights_name) new_key.append(i)