From 1f9517830e91c2234c2fc1a9b163ca353a8347f3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 7 Dec 2023 07:34:24 +0000 Subject: [PATCH] fix --- .../workloads/fastmri/fastmri_pytorch/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index 6694d83ad..555af5fa5 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -136,8 +136,8 @@ def __init__(self, if use_layer_norm: size = int(size) - norm_layer = LayerNorm - normalized_shape = (out_chans, size, size) + norm_layer = nn.GroupNorm + normalized_shape = (1, out_chans) else: norm_layer = nn.InstanceNorm2d normalized_shape = out_chans @@ -174,8 +174,8 @@ def __init__(self, super().__init__() if use_layer_norm: size = int(size) - norm_layer = LayerNorm - normalized_shape = (out_chans, size, size) + norm_layer = nn.GroupNorm + normalized_shape = (1, out_chans) else: norm_layer = nn.InstanceNorm2d normalized_shape = out_chans