From 879cd2d9029eeae939c2d6085505b7bb211bd103 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 7 Dec 2023 08:05:01 +0000 Subject: [PATCH] fix --- .../workloads/fastmri/fastmri_pytorch/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index 38bf73892..cd73709c5 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -173,7 +173,7 @@ def __init__(self, super().__init__() if use_layer_norm: size = int(size) - norm_layer = nn.GroupNorm(num_groups=1, num_channels=out_chans, eps=1e-6) + norm_layer = partial(nn.GroupNorm, 1, eps=1e-6) else: norm_layer = nn.InstanceNorm2d if use_tanh: @@ -183,7 +183,6 @@ def __init__(self, self.layers = nn.Sequential( nn.ConvTranspose2d( in_chans, out_chans, kernel_size=2, stride=2, bias=False), - nn.GroupNorm(num_groups=1, num_channels=out_chans, eps=1e-6), norm_layer(out_chans), activation_fn, )