From 34392048feea27887de06012f6e5053ec198e452 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 7 Dec 2023 06:24:35 +0000 Subject: [PATCH] fix --- .../workloads/fastmri/fastmri_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index 24bb233b9..b21e7d427 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -41,7 +41,7 @@ def init_model_fn( console_kwargs={ 'force_terminal': False, 'force_jupyter': False, 'width': 240}, ) - print(tabulate_fn(fake_inputs, train=False)) + print(tabulate_fn(fake_batch, train=False)) variables = jax.jit(self._model.init)({'params': rng}, fake_batch) params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params)