diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
index aa546b601..0b8ab014f 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
@@ -122,7 +122,7 @@ def __init__(self,
     if use_layer_norm:
       size = int(size)
-      norm_layer = nn.LayerNorm([out_chans, size, size], eps=1e-06)
+      norm_layer = lambda: nn.LayerNorm([out_chans, size, size], eps=1e-06)
     else:
-      norm_layer = nn.InstanceNorm2d(out_chans)
+      norm_layer = lambda: nn.InstanceNorm2d(out_chans)
     if use_tanh:
@@ -131,11 +131,11 @@ def __init__(self,
       activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
     self.conv_layers = nn.Sequential(
         nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
-        norm_layer,
+        norm_layer(),
         activation_fn,
         nn.Dropout2d(dropout_rate),
         nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
-        norm_layer,
+        norm_layer(),
         activation_fn,
         nn.Dropout2d(dropout_rate),
     )