Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 7, 2023
1 parent 49266e3 commit 8879598
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __call__(self, x, train=True):
strides=(1, 1),
use_bias=False)(x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
# DO NOT SUBMIT check that this comment edit is correct
# InstanceNorm2d was run with no learnable params in reference code
Expand All @@ -176,7 +176,7 @@ def __call__(self, x, train=True):
strides=(1, 1),
use_bias=False)(x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
Expand Down Expand Up @@ -205,7 +205,7 @@ def __call__(self, x):
self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
x)
if self.use_layer_norm:
x = nn.LayerNorm()(x)
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
Expand Down

0 comments on commit 8879598

Please sign in to comment.