diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index 24bb233b9..b21e7d427 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -41,7 +41,7 @@ def init_model_fn( console_kwargs={ 'force_terminal': False, 'force_jupyter': False, 'width': 240}, ) - print(tabulate_fn(fake_inputs, train=False)) + print(tabulate_fn(fake_batch, train=False)) variables = jax.jit(self._model.init)({'params': rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params)