From 88795984e922e7a1b709fe78fd8fff905f79a62f Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Thu, 7 Dec 2023 07:04:01 +0000
Subject: [PATCH] fix

---
 .../workloads/fastmri/fastmri_jax/models.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
index 2d34acd29..f4d7aaba0 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
@@ -155,7 +155,7 @@ def __call__(self, x, train=True):
         strides=(1, 1),
         use_bias=False)(x)
     if self.use_layer_norm:
-      x = nn.LayerNorm()(x)
+      x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
     else:
       # DO NOT SUBMIT check that this comment edit is correct
       # InstanceNorm2d was run with no learnable params in reference code
@@ -176,7 +176,7 @@ def __call__(self, x, train=True):
         strides=(1, 1),
         use_bias=False)(x)
     if self.use_layer_norm:
-      x = nn.LayerNorm()(x)
+      x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
     else:
       x = _instance_norm2d(x, (1, 2))
     x = activation_fn(x)
@@ -205,7 +205,7 @@ def __call__(self, x):
         self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
            x)
     if self.use_layer_norm:
-      x = nn.LayerNorm()(x)
+      x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
    else:
      x = _instance_norm2d(x, (1, 2))
    if self.use_tanh:
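
Below is a minimal sketch (not part of the patch) of what the change does, assuming a recent Flax release where `flax.linen.LayerNorm` exposes the `reduction_axes` argument and an NHWC tensor layout as in the FastMRI U-Net blocks above; the variable names are illustrative only. The default `nn.LayerNorm()` computes statistics over the last (channel) axis only, whereas `reduction_axes=(1, 2, 3)` computes one mean and variance per example over height, width, and channels jointly.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 4))  # (N, H, W, C)

# Default LayerNorm: statistics are taken over the last axis (C) at each
# spatial position independently.
ln_default = nn.LayerNorm()
y_default = ln_default.apply(ln_default.init(jax.random.PRNGKey(1), x), x)

# Patched LayerNorm: statistics are taken jointly over H, W and C, i.e. one
# mean/variance per example (the learnable scale/bias remain per-channel via
# the default feature_axes=-1).
ln_hwc = nn.LayerNorm(reduction_axes=(1, 2, 3))
y_hwc = ln_hwc.apply(ln_hwc.init(jax.random.PRNGKey(1), x), x)

# With default (identity) scale/bias, each example is approximately
# zero-mean / unit-variance over all of H * W * C.
print(jnp.mean(y_hwc, axis=(1, 2, 3)))  # ~0 per example
print(jnp.std(y_hwc, axis=(1, 2, 3)))   # ~1 per example
```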