From ed112133aa3784af1f60ee99357e87477b422efa Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Thu, 7 Dec 2023 06:34:41 +0000
Subject: [PATCH] fix

---
 .../fastmri/fastmri_pytorch/models.py | 28 ++++++++++++++-----
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
index 0b8ab014f..e8c5a80a7 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
@@ -14,6 +14,20 @@
 from algorithmic_efficiency import init_utils
 
 
+class LayerNorm(nn.Module):
+
+  def __init__(self, dim, epsilon=1e-6):
+    super().__init__()
+    self.dim = dim
+
+    self.scale = nn.Parameter(torch.zeros(self.dim))
+    self.bias = nn.Parameter(torch.zeros(self.dim))
+    self.epsilon = epsilon
+
+  def forward(self, x):
+    return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon)
+
+
 class UNet(nn.Module):
   r"""U-Net model from
     `"U-net: Convolutional networks
@@ -122,20 +136,20 @@ def __init__(self,
     if use_layer_norm:
       size = int(size)
-      norm_layer = nn.LayerNorm
+      norm_layer = LayerNorm
     else:
-      norm_layer = nn.InstanceNorm2d(out_chans)
+      norm_layer = nn.InstanceNorm2d
     if use_tanh:
       activation_fn = nn.Tanh()
     else:
       activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
     self.conv_layers = nn.Sequential(
         nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
-        norm_layer([out_chans, size, size], eps=1e-06),
+        norm_layer(out_chans),
         activation_fn,
         nn.Dropout2d(dropout_rate),
         nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
-        norm_layer([out_chans, size, size], eps=1e-06),
+        norm_layer(out_chans),
         activation_fn,
         nn.Dropout2d(dropout_rate),
     )
 
@@ -158,9 +172,9 @@ def __init__(self,
     super().__init__()
     if use_layer_norm:
       size = int(size)
-      norm_layer = nn.LayerNorm([out_chans, size, size], eps=1e-06)
+      norm_layer = LayerNorm
     else:
-      norm_layer = nn.InstanceNorm2d(out_chans)
+      norm_layer = nn.InstanceNorm2d
     if use_tanh:
       activation_fn = nn.Tanh()
     else:
@@ -168,7 +182,7 @@ def __init__(self,
     self.layers = nn.Sequential(
         nn.ConvTranspose2d(
             in_chans, out_chans, kernel_size=2, stride=2, bias=False),
-        norm_layer,
+        norm_layer(out_chans),
         activation_fn,
     )
 