From 027574a39f6fd37a1edaac16ac6534f77a2bdd57 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 6 Dec 2023 01:16:08 +0000 Subject: [PATCH] debugging --- .../workloads/fastmri/fastmri_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index 4cef9808f..0b8b3b193 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -114,7 +114,7 @@ def __init__(self, super().__init__() if use_layer_norm: - norm_layer = nn.LayerNorm() + norm_layer = nn.LayerNorm([out_chans, 320, 320]) else: norm_layer = nn.InstanceNorm2d(out_chans) if use_tanh: @@ -148,7 +148,7 @@ def __init__(self, ): super().__init__() if use_layer_norm: - norm_layer = nn.LayerNorm() + norm_layer = nn.LayerNorm([out_chans, 320, 320) else: norm_layer = nn.InstanceNorm2d(out_chans) if use_tanh: