From 515f16b08a5a910476334471392b5e25cf91418f Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Thu, 7 Dec 2023 03:11:35 +0000
Subject: [PATCH] fix

---
 .../workloads/fastmri/fastmri_pytorch/models.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
index 5199e5aa6..6bbd18645 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py
@@ -52,18 +52,18 @@ def __init__(self,
     self.up_conv = nn.ModuleList()
     self.up_transpose_conv = nn.ModuleList()
-    # size = size * 2
+
     for _ in range(num_pool_layers - 1):
+      size = int(size * 2)
       self.up_transpose_conv.append(
           TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm, size))
-      size = int(size * 2)
       self.up_conv.append(
           ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm, size))
       ch //= 2
 
+    size = int(size * 2)
     self.up_transpose_conv.append(
         TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm, size))
-    size = int(size * 2)
     self.up_conv.append(
         nn.Sequential(
             ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm, size),