diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py index 8d15a0f43..44bff0e21 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py @@ -126,7 +126,9 @@ def __call__(self, x, train=True): output = jnp.concatenate((output, downsample_layer), axis=-1) output = conv(output, train) - output = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))(output) + output = nn.Conv( + self.out_channels, kernel_size=(1, 1), strides=(1, 1))( + output) return output.squeeze(-1)