diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index 5199e5aa6..6bbd18645 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -52,18 +52,18 @@ def __init__(self, self.up_conv = nn.ModuleList() self.up_transpose_conv = nn.ModuleList() - # size = size * 2 + for _ in range(num_pool_layers - 1): + size = int(size * 2) self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm, size)) - size = int(size * 2) self.up_conv.append( ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm, size)) ch //= 2 + size = int(size * 2) self.up_transpose_conv.append( TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm, size)) - size = int(size * 2) self.up_conv.append( nn.Sequential( ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm, size),