Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 669491071
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Aug 31, 2024
1 parent 057c93c commit 16e844e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/diffusion/unets.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def __call__(self, emb: Array, cond: dict[str, Array], is_training: bool):
dtype=self.dtype,
param_dtype=self.param_dtype,
)(value.reshape(b, -1))
value = nn.swish(value)
value = nn.swish(nn.LayerNorm()(value))

# Concatenate the noise and conditional embedding.
emb = jnp.concatenate([emb, value], axis=-1)
Expand Down

0 comments on commit 16e844e

Please sign in to comment.