From 16e844e7d78f5a1d7934e22cc56cbdd2e3b2740e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leonardo=20Zepeda-N=C3=BA=C3=B1ez?= Date: Fri, 30 Aug 2024 17:06:23 -0700 Subject: [PATCH] Code update PiperOrigin-RevId: 669491071 --- swirl_dynamics/lib/diffusion/unets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swirl_dynamics/lib/diffusion/unets.py b/swirl_dynamics/lib/diffusion/unets.py index e530590..26439c4 100644 --- a/swirl_dynamics/lib/diffusion/unets.py +++ b/swirl_dynamics/lib/diffusion/unets.py @@ -573,7 +573,7 @@ def __call__(self, emb: Array, cond: dict[str, Array], is_training: bool): dtype=self.dtype, param_dtype=self.param_dtype, )(value.reshape(b, -1)) - value = nn.swish(value) + value = nn.swish(nn.LayerNorm()(value)) # Concatenate the noise and conditional embedding. emb = jnp.concatenate([emb, value], axis=-1)