diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index 3ee0030ff..4cef9808f 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -114,8 +114,7 @@ def __init__(self, super().__init__() if use_layer_norm: - norm_layer = nn.LayerNorm(out_chans) - # norm_layer = nn.InstanceNorm2d(out_chans) + norm_layer = nn.LayerNorm() else: norm_layer = nn.InstanceNorm2d(out_chans) if use_tanh: @@ -149,8 +148,7 @@ def __init__(self, ): super().__init__() if use_layer_norm: - norm_layer = nn.LayerNorm(out_chans) - # norm_layer = nn.InstanceNorm2d(out_chans) + norm_layer = nn.LayerNorm() else: norm_layer = nn.InstanceNorm2d(out_chans) if use_tanh: