Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 7, 2023
1 parent 3439204 commit ed11213
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
from algorithmic_efficiency import init_utils


class LayerNorm(nn.Module):

def __init__(self, dim, epsilon=1e-6):
super().__init__()
self.dim = dim

self.scale = nn.Parameter(torch.zeros(self.dim))
self.bias = nn.Parameter(torch.zeros(self.dim))
self.epsilon = epsilon

def forward(self, x):
return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon)


class UNet(nn.Module):
r"""U-Net model from
`"U-net: Convolutional networks
Expand Down Expand Up @@ -122,20 +136,20 @@ def __init__(self,

if use_layer_norm:
size = int(size)
norm_layer = nn.LayerNorm
norm_layer = LayerNorm
else:
norm_layer = nn.InstanceNorm2d(out_chans)
norm_layer = nn.InstanceNorm2d
if use_tanh:
activation_fn = nn.Tanh()
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv_layers = nn.Sequential(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer([out_chans, size, size], eps=1e-06),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer([out_chans, size, size], eps=1e-06),
norm_layer(out_chans),
activation_fn,
nn.Dropout2d(dropout_rate),
)
Expand All @@ -158,17 +172,17 @@ def __init__(self,
super().__init__()
if use_layer_norm:
size = int(size)
norm_layer = nn.LayerNorm([out_chans, size, size], eps=1e-06)
norm_layer = LayerNorm
else:
norm_layer = nn.InstanceNorm2d(out_chans)
norm_layer = nn.InstanceNorm2d
if use_tanh:
activation_fn = nn.Tanh()
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layers = nn.Sequential(
nn.ConvTranspose2d(
in_chans, out_chans, kernel_size=2, stride=2, bias=False),
norm_layer,
norm_layer(out_chans),
activation_fn,
)

Expand Down

0 comments on commit ed11213

Please sign in to comment.