diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index f19191875..c9c6ccc7e 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -62,7 +62,7 @@ def __init__(self, ch //= 2 self.up_transpose_conv.append( - TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)) + TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm, size)) size = int(size * 2) self.up_conv.append( nn.Sequential(