diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index 38bf73892..cd73709c5 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -173,7 +173,7 @@ def __init__(self, super().__init__() if use_layer_norm: size = int(size) - norm_layer = nn.GroupNorm(num_groups=1, num_channels=out_chans, eps=1e-6) + norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) else: norm_layer = nn.InstanceNorm2d if use_tanh: @@ -183,7 +183,6 @@ def __init__(self, self.layers = nn.Sequential( nn.ConvTranspose2d( in_chans, out_chans, kernel_size=2, stride=2, bias=False), - nn.GroupNorm(num_groups=1, num_channels=out_chans, eps=1e-6), norm_layer(out_chans), activation_fn, )