From d62ff77c71b1e8bd2a612d8e85e135a056bc56b6 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 7 Dec 2023 03:15:31 +0000 Subject: [PATCH] add casting --- .../workloads/fastmri/fastmri_pytorch/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index c635fc595..aa546b601 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -121,6 +121,7 @@ def __init__(self, super().__init__() if use_layer_norm: + size = int(size) norm_layer = nn.LayerNorm([out_chans, size, size], eps=1e-06) else: norm_layer = nn.InstanceNorm2d(out_chans) @@ -156,6 +157,7 @@ def __init__(self, ): super().__init__() if use_layer_norm: + size = int(size) norm_layer = nn.LayerNorm([out_chans, size, size], eps=1e-06) else: norm_layer = nn.InstanceNorm2d(out_chans)