From 1f902a52bd74f6d88588338f56c47a8df3d4b626 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Leonardo=20Zepeda-N=C3=BA=C3=B1ez?=
Date: Thu, 11 Apr 2024 13:41:30 -0700
Subject: [PATCH] Code update

PiperOrigin-RevId: 623926391
---
 swirl_dynamics/lib/diffusion/unets.py        | 167 +++++++++++++++++--
 swirl_dynamics/lib/diffusion/unets3d.py      |  55 +++++-
 swirl_dynamics/lib/layers/axial_attention.py |  11 ++
 swirl_dynamics/lib/layers/convolutions.py    |  45 ++++-
 swirl_dynamics/lib/layers/residual.py        |  14 +-
 swirl_dynamics/lib/layers/resize.py          |  16 +-
 6 files changed, 292 insertions(+), 16 deletions(-)

diff --git a/swirl_dynamics/lib/diffusion/unets.py b/swirl_dynamics/lib/diffusion/unets.py
index afa8819..9bd443b 100644
--- a/swirl_dynamics/lib/diffusion/unets.py
+++ b/swirl_dynamics/lib/diffusion/unets.py
@@ -24,6 +24,13 @@
 Array = jax.Array
 Initializer = nn.initializers.Initializer
+PrecisionLike = (
+    None
+    | str
+    | jax.lax.Precision
+    | tuple[str, str]
+    | tuple[jax.lax.Precision, jax.lax.Precision]
+)


 def default_init(scale: float = 1e-10) -> Initializer:
@@ -44,6 +51,9 @@ class AdaptiveScale(nn.Module):
   more general FiLM technique see https://arxiv.org/abs/1709.07871.
   """
   act_fun: Callable[[Array], Array] = nn.swish
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array, emb: Array) -> Array:
@@ -60,7 +70,13 @@ def __call__(self, x: Array, emb: Array) -> Array:
         "The dimension of the embedding needs to be two, instead it was : "
         + str(emb.ndim)
     )
-    affine = nn.Dense(features=x.shape[-1] * 2, kernel_init=default_init(1.0))
+    affine = nn.Dense(
+        features=x.shape[-1] * 2,
+        kernel_init=default_init(1.0),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
+    )
     scale_params = affine(self.act_fun(emb))
     # Unsqueeze in the middle to allow broadcasting.
     scale_params = scale_params.reshape(
@@ -74,7 +90,9 @@ class AttentionBlock(nn.Module):
   """Attention block."""

   num_heads: int = 1
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array, is_training: bool) -> Array:
@@ -84,6 +102,8 @@ def __call__(self, x: Array, is_training: bool) -> Array:
         kernel_init=nn.initializers.xavier_uniform(),
         deterministic=not is_training,
         dtype=self.dtype,
+        precision=self.precision,
+        param_dtype=self.param_dtype,
         name="dot_attn",
     )(h, h)
     return layers.CombineResidualWithSkip()(residual=h, skip=x)
@@ -95,6 +115,9 @@ class ResConv1x(nn.Module):
   hidden_layer_size: int
   out_channels: int
   act_fun: Callable[[Array], Array] = nn.swish
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array) -> Array:
@@ -104,14 +127,24 @@ def __call__(self, x: Array) -> Array:
         features=self.hidden_layer_size,
         kernel_size=kernel_size,
         kernel_init=default_init(1.0),
+        dtype=self.dtype,
+        precision=self.precision,
+        param_dtype=self.param_dtype,
     )(x)
     x = self.act_fun(x)
     x = nn.Conv(
         features=self.out_channels,
         kernel_size=kernel_size,
         kernel_init=default_init(1.0),
+        dtype=self.dtype,
+        precision=self.precision,
+        param_dtype=self.param_dtype,
     )(x)
-    return layers.CombineResidualWithSkip()(residual=x, skip=skip)
+    return layers.CombineResidualWithSkip(
+        dtype=self.dtype,
+        precision=self.precision,
+        param_dtype=self.param_dtype,
+    )(residual=x, skip=skip)


 class ConvBlock(nn.Module):
@@ -138,6 +171,9 @@ class ConvBlock(nn.Module):
   dropout: float = 0.0
   film_act_fun: Callable[[Array], Array] = nn.swish
   act_fun: Callable[[Array], Array] = nn.swish
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array, emb: Array, is_training: bool) -> Array:
@@ -149,6 +185,9 @@ def __call__(self, x: Array, emb: Array, is_training: bool) -> Array:
         kernel_size=self.kernel_size,
         padding=self.padding,
         kernel_init=default_init(1.0),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
         name="conv_0",
     )(h)
     h = nn.GroupNorm(min(h.shape[-1] // 4, 32))(h)
@@ -160,9 +199,17 @@ def __call__(self, x: Array, emb: Array, is_training: bool) -> Array:
         kernel_size=self.kernel_size,
         padding=self.padding,
         kernel_init=default_init(1.0),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
         name="conv_1",
     )(h)
-    return layers.CombineResidualWithSkip(project_skip=True)(residual=h, skip=x)
+    return layers.CombineResidualWithSkip(
+        project_skip=True,
+        dtype=self.dtype,
+        precision=self.precision,
+        param_dtype=self.param_dtype,
+    )(residual=h, skip=x)


 class FourierEmbedding(nn.Module):
@@ -172,6 +219,9 @@ class FourierEmbedding(nn.Module):
   max_freq: float = 2e4
   projection: bool = True
   act_fun: Callable[[Array], Array] = nn.swish
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array) -> Array:
@@ -181,9 +231,19 @@ def __call__(self, x: Array) -> Array:
     x = jnp.concatenate([jnp.sin(x), jnp.cos(x)], axis=-1)

     if self.projection:
-      x = nn.Dense(features=2 * self.dims)(x)
+      x = nn.Dense(
+          features=2 * self.dims,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
+      )(x)
       x = self.act_fun(x)
-      x = nn.Dense(features=self.dims)(x)
+      x = nn.Dense(
+          features=self.dims,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
+      )(x)
     return x


@@ -244,6 +304,9 @@ class Axial2DMLP(nn.Module):
   out_dims: tuple[int, int]
   act_fn: Callable[[Array], Array] = nn.swish
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array) -> Array:
@@ -253,11 +316,21 @@ def __call__(self, x: Array) -> Array:
     for i, out_dim in enumerate(self.out_dims):
       spatial_dim = -3 + i
       x = jnp.swapaxes(x, spatial_dim, -1)
-      x = nn.Dense(features=out_dim)(x)
+      x = nn.Dense(
+          features=out_dim,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
+      )(x)
       x = nn.swish(x)
       x = jnp.swapaxes(x, -1, spatial_dim)

-    x = nn.Dense(features=num_channels)(x)
+    x = nn.Dense(
+        features=num_channels,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
+    )(x)
     return x


@@ -275,6 +348,9 @@ class MergeChannelCond(nn.Module):
   kernel_size: Sequence[int]
   resize_method: str = "cubic"
   padding: str = "CIRCULAR"
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32


 class InterpConvMerge(MergeChannelCond):
@@ -315,14 +391,19 @@ def __call__(self, x: Array, cond: dict[str, Array]):
             kernel_size=self.kernel_size,
             method=self.resize_method,
             padding=self.padding,
+            precision=self.precision,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
             name=f"resize_{key}",
         )(value)

-      value = nn.swish(nn.LayerNorm()(value))
       value = layers.ConvLayer(
           features=self.embed_dim,
           kernel_size=self.kernel_size,
           padding=self.padding,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
           name=f"conv2d_embed_{key}",
       )(value)

@@ -362,6 +443,9 @@ def __call__(self, x: Array, cond: dict[str, Array]):
         kernel_size=self.kernel_size,
         resize_method=self.resize_method,
         padding=self.padding,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )
     return merge_channel_cond(x, proc_cond)

@@ -383,7 +467,9 @@ class DStack(nn.Module):
   num_heads: int = 8
   channels_per_head: int = -1
   use_position_encoding: bool = False
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array, emb: Array, *, is_training: bool) -> list[Array]:
@@ -400,6 +486,9 @@ def __call__(self, x: Array, emb: Array, *, is_training: bool) -> list[Array]:
         kernel_size=kernel_dim * (3,),
         padding=self.padding,
         kernel_init=default_init(1.0),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
         name="conv_in",
     )(x)
     skips.append(h)
@@ -409,6 +498,9 @@ def __call__(self, x: Array, emb: Array, *, is_training: bool) -> list[Array]:
           features=channel,
           ratios=(self.downsample_ratio[level],) * kernel_dim,
           kernel_init=default_init(1.0),
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
           name=f"res{'x'.join(res.astype(str))}.downsample_conv",
       )(h)
       res = res // self.downsample_ratio[level]
@@ -418,6 +510,9 @@
             kernel_size=kernel_dim * (3,),
             padding=self.padding,
             dropout=self.dropout_rate,
+            precision=self.precision,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
             name=f"res{'x'.join(res.astype(str))}.down.block{block_id}",
         )(h, emb, is_training=is_training)

@@ -431,12 +526,17 @@
          )(h)
          h = AttentionBlock(
              num_heads=self.num_heads,
+             precision=self.precision,
              dtype=self.dtype,
+             param_dtype=self.param_dtype,
              name=f"res{'x'.join(res.astype(str))}.down.block{block_id}.attn",
          )(h.reshape(b, -1, c), is_training=is_training)
          h = ResConv1x(
              hidden_layer_size=channel * 2,
              out_channels=channel,
+             precision=self.precision,
+             dtype=self.dtype,
+             param_dtype=self.param_dtype,
              name=f"res{'x'.join(res.astype(str))}.down.block{block_id}.res_conv_1x",
          )(h).reshape(b, *hw, c)
        skips.append(h)
@@ -472,7 +572,9 @@ class UStack(nn.Module):
   use_attention: bool = False
   num_heads: int = 8
   channels_per_head: int = -1
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(
@@ -489,25 +591,36 @@ def __call__(
     for level, channel in enumerate(self.num_channels):
       for block_id in range(self.num_res_blocks[level]):
         h = layers.CombineResidualWithSkip(
-            project_skip=h.shape[-1] != skips[-1].shape[-1]
+            project_skip=h.shape[-1] != skips[-1].shape[-1],
+            precision=self.precision,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
         )(residual=h, skip=skips.pop())
         h = ConvBlock(
             out_channels=channel,
             kernel_size=kernel_dim * (3,),
             padding=self.padding,
            dropout=self.dropout_rate,
+            precision=self.precision,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
             name=f"res{'x'.join(res.astype(str))}.up.block{block_id}",
         )(h, emb, is_training=is_training)
         if self.use_attention and level == 0:  # opposite to DStack
           b, *hw, c = h.shape
           h = AttentionBlock(
               num_heads=self.num_heads,
+              precision=self.precision,
               dtype=self.dtype,
+              param_dtype=self.param_dtype,
               name=f"res{'x'.join(res.astype(str))}.up.block{block_id}.attn",
           )(h.reshape(b, -1, c), is_training=is_training)
           h = ResConv1x(
               hidden_layer_size=channel * 2,
               out_channels=channel,
+              precision=self.precision,
+              dtype=self.dtype,
+              param_dtype=self.param_dtype,
               name=f"res{'x'.join(res.astype(str))}.up.block{block_id}.res_conv_1x",
           )(h).reshape(b, *hw, c)

@@ -518,19 +631,28 @@ def __call__(
           kernel_size=kernel_dim * (3,),
           padding=self.padding,
           kernel_init=default_init(1.0),
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
           name=f"res{'x'.join(res.astype(str))}.conv_upsample",
       )(h)
       h = layers.channel_to_space(h, block_shape=kernel_dim * (up_ratio,))
       res = res * up_ratio

     h = layers.CombineResidualWithSkip(
-        project_skip=h.shape[-1] != skips[-1].shape[-1]
+        project_skip=h.shape[-1] != skips[-1].shape[-1],
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(residual=h, skip=skips.pop())
     h = layers.ConvLayer(
         features=128,
         kernel_size=kernel_dim * (3,),
         padding=self.padding,
         kernel_init=default_init(1.0),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
         name="conv_out",
     )(h)
     return h
@@ -553,6 +675,9 @@ class UNet(nn.Module):
   cond_resize_method: str = "bilinear"
   cond_embed_dim: int = 128
   cond_merging_fn: type[MergeChannelCond] = InterpConvMerge
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(
@@ -593,6 +718,9 @@ def __call__(
           output_size=self.resize_to_shape,
           kernel_size=(7, 7),
           padding=self.padding,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
       )(x)

     kernel_dim = x.ndim - 2
@@ -602,6 +730,9 @@ def __call__(
         resize_method=self.cond_resize_method,
         kernel_size=(3,) * kernel_dim,
         padding=self.padding,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(x, cond)

     emb = FourierEmbedding(dims=self.noise_embed_dim)(sigma)
@@ -614,6 +745,9 @@
         use_attention=self.use_attention,
         num_heads=self.num_heads,
         use_position_encoding=self.use_position_encoding,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(x, emb, is_training=is_training)
     h = UStack(
         num_channels=self.num_channels[::-1],
@@ -623,6 +757,9 @@
         dropout_rate=self.dropout_rate,
         use_attention=self.use_attention,
         num_heads=self.num_heads,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(skips[-1], emb, skips, is_training=is_training)

     h = nn.swish(nn.GroupNorm(min(h.shape[-1] // 4, 32))(h))
@@ -631,12 +768,20 @@
         kernel_size=kernel_dim * (3,),
         padding=self.padding,
         kernel_init=default_init(),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
         name="conv_out",
     )(h)

     if self.resize_to_shape:
       h = layers.FilteredResize(
-          output_size=input_size, kernel_size=(7, 7), padding=self.padding
+          output_size=input_size,
+          kernel_size=(7, 7),
+          padding=self.padding,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
       )(h)
     return h

diff --git a/swirl_dynamics/lib/diffusion/unets3d.py b/swirl_dynamics/lib/diffusion/unets3d.py
index 9c566c3..0852326 100644
--- a/swirl_dynamics/lib/diffusion/unets3d.py
+++ b/swirl_dynamics/lib/diffusion/unets3d.py
@@ -30,6 +30,13 @@
 from swirl_dynamics.lib.diffusion import unets

 Array = jax.Array
+PrecisionLike = (
+    None
+    | str
+    | jax.lax.Precision
+    | tuple[str, str]
+    | tuple[jax.lax.Precision, jax.lax.Precision]
+)


 def _maybe_broadcast_to_list(
@@ -50,7 +57,9 @@ class AxialSelfAttentionBlock(nn.Module):
   attention_axes: int | Sequence[int] = -2
   add_position_embedding: bool | Sequence[bool] = True
   num_heads: int | Sequence[int] = 1
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array, is_training: bool) -> Array:
@@ -83,6 +92,8 @@ def __call__(self, x: Array, is_training: bool) -> Array:
           kernel_init=nn.initializers.xavier_uniform(),
           deterministic=not is_training,
           dtype=self.dtype,
+          param_dtype=self.param_dtype,
+          precision=self.precision,
           name=f"axial_attn_axis{axis}",
       )(h)
       h = nn.GroupNorm(
@@ -111,7 +122,9 @@ class DStack(nn.Module):
   dropout_rate: float = 0.0
   num_heads: int = 8
   use_position_encoding: bool = False
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, x: Array, emb: Array, *, is_training: bool) -> list[Array]:
@@ -152,6 +165,9 @@ def __call__(self, x: Array, emb: Array, *, is_training: bool) -> list[Array]:
            kernel_size=(3, 3),
            padding=self.padding,
            dropout=self.dropout_rate,
+            precision=self.precision,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
            name=f"{nt}xres{'x'.join(dims_str)}.dblock{block_id}",
        )(h, emb, is_training=is_training)
        if (
@@ -168,6 +184,9 @@
              attention_axes=attn_axes,
              add_position_embedding=self.use_position_encoding,
              num_heads=self.num_heads,
+              precision=self.precision,
+              dtype=self.dtype,
+              param_dtype=self.param_dtype,
              name=f"{nt}xres{'x'.join(dims_str)}.dblock{block_id}.attn",
          )(h, is_training=is_training)
        skips.append(h)
@@ -193,7 +212,9 @@ class UStack(nn.Module):
   dropout_rate: float = 0.0
   use_position_encoding: bool = False
   num_heads: int = 8
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(
@@ -222,6 +243,9 @@ def __call__(
            kernel_size=(3, 3),
            padding=self.padding,
            dropout=self.dropout_rate,
+            precision=self.precision,
+            dtype=self.dtype,
+            param_dtype=self.param_dtype,
            name=f"{nt}xres{'x'.join(dims_str)}.ublock{block_id}",
        )(h, emb, is_training=is_training)
        if (
@@ -238,6 +262,9 @@
              attention_axes=attn_axes,
              add_position_embedding=self.use_position_encoding,
              num_heads=self.num_heads,
+              precision=self.precision,
+              dtype=self.dtype,
+              param_dtype=self.param_dtype,
              name=f"{nt}xres{'x'.join(dims_str)}.ublock{block_id}.attn",
          )(h, is_training=is_training)

@@ -248,6 +275,9 @@
          kernel_size=(3, 3),
          padding=self.padding,
          kernel_init=unets.default_init(1.0),
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
          name=f"{nt}xres{'x'.join(dims_str)}.conv2d_preupsample",
      )(h)
      h = layers.channel_to_space(inputs=h, block_shape=(up_ratio, up_ratio))
@@ -317,6 +347,9 @@ class UNet3d(nn.Module):
   num_heads: int = 8
   cond_resize_method: str = "cubic"
   cond_embed_dim: int = 128
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(
@@ -376,6 +409,9 @@ def __call__(
           output_size=self.resize_to_shape,
           kernel_size=(7, 7),
           padding=self.padding,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
       )(x)

     cond = {} if cond is None else cond
@@ -384,6 +420,9 @@
         resize_method=self.cond_resize_method,
         kernel_size=(3, 3),
         padding=self.padding,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(x, cond)

     emb = unets.FourierEmbedding(dims=self.noise_embed_dim)(sigma)
@@ -398,6 +437,9 @@
         use_temporal_attention=use_temporal_attn,
         num_heads=self.num_heads,
         use_position_encoding=self.use_position_encoding,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(x, emb, is_training=is_training)
     h = UStack(
         num_channels=self.num_channels[::-1],
@@ -409,6 +451,9 @@
         use_spatial_attention=use_spatial_attn,
         use_temporal_attention=use_temporal_attn,
         num_heads=self.num_heads,
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(skips[-1], emb, skips, is_training=is_training)
     h = nn.swish(nn.GroupNorm(min(h.shape[-1] // 4, 32))(h))
     h = layers.ConvLayer(
@@ -416,12 +461,20 @@
         kernel_size=(3, 3),
         padding=self.padding,
         kernel_init=unets.default_init(),
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
         name="conv2d_out",
     )(h)

     if self.resize_to_shape is not None:
       h = layers.FilteredResize(
-          output_size=input_size, kernel_size=(7, 7), padding=self.padding
+          output_size=input_size,
+          kernel_size=(7, 7),
+          padding=self.padding,
+          precision=self.precision,
+          dtype=self.dtype,
+          param_dtype=self.param_dtype,
       )(h)
     return h

diff --git a/swirl_dynamics/lib/layers/axial_attention.py b/swirl_dynamics/lib/layers/axial_attention.py
index 1a6076a..27bb6c8 100644
--- a/swirl_dynamics/lib/layers/axial_attention.py
+++ b/swirl_dynamics/lib/layers/axial_attention.py
@@ -20,6 +20,13 @@
 Array = jax.Array
+PrecisionLike = (
+    None
+    | str
+    | jax.lax.Precision
+    | tuple[str, str]
+    | tuple[jax.lax.Precision, jax.lax.Precision]
+)


 class AddAxialPositionEmbedding(nn.Module):
@@ -62,7 +69,9 @@ class AxialSelfAttention(nn.Module):
   attention_axis: int = -2
   kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
   deterministic: bool = True
+  precision: PrecisionLike = None
   dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, inputs: Array) -> Array:
@@ -82,6 +91,8 @@ def __call__(self, inputs: Array) -> Array:
         kernel_init=self.kernel_init,
         deterministic=self.deterministic,
         dtype=self.dtype,
+        param_dtype=self.param_dtype,
+        precision=self.precision,
     )(inputs_q=inputs_q)

     out = jnp.reshape(out, inputs.shape)

diff --git a/swirl_dynamics/lib/layers/convolutions.py b/swirl_dynamics/lib/layers/convolutions.py
index 1d2789e..b9d4b55 100644
--- a/swirl_dynamics/lib/layers/convolutions.py
+++ b/swirl_dynamics/lib/layers/convolutions.py
@@ -23,6 +23,13 @@
 import numpy as np

 Array = jax.Array
+PrecisionLike = (
+    None
+    | str
+    | jax.lax.Precision
+    | tuple[str, str]
+    | tuple[jax.lax.Precision, jax.lax.Precision]
+)


 def ConvLayer(
@@ -30,6 +37,9 @@ def ConvLayer(
     kernel_size: int | Sequence[int],
     padding: nn.linear.PaddingLike,
     use_local: bool = False,
+    precision: PrecisionLike = None,
+    dtype: jnp.dtype = jnp.float32,
+    param_dtype: jnp.dtype = jnp.float32,
     **kwargs,
 ) -> nn.Module:
   """Factory for different types of convolution layers."""
@@ -44,12 +54,31 @@ def ConvLayer(
         kernel_size,
         use_local=use_local,
         order=padding.lower(),
+        precision=precision,
+        dtype=dtype,
+        param_dtype=param_dtype,
         **kwargs,
     )
   elif use_local:
-    return nn.ConvLocal(features, kernel_size, padding=padding, **kwargs)
+    return nn.ConvLocal(
+        features,
+        kernel_size,
+        padding=padding,
+        precision=precision,
+        dtype=dtype,
+        param_dtype=param_dtype,
+        **kwargs,
+    )
   else:
-    return nn.Conv(features, kernel_size, padding=padding, **kwargs)
+    return nn.Conv(
+        features,
+        kernel_size,
+        padding=padding,
+        precision=precision,
+        dtype=dtype,
+        param_dtype=param_dtype,
+        **kwargs,
+    )


 class LatLonConv(nn.Module):
@@ -64,6 +93,9 @@ class LatLonConv(nn.Module):
   )
   strides: tuple[int, int] = (1, 1)
   use_local: bool = False
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, inputs: Array) -> Array:
@@ -114,6 +146,9 @@ def __call__(self, inputs: Array) -> Array:
         strides=self.strides,
         kernel_init=self.kernel_init,
         padding="VALID",
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(padded_inputs)


@@ -126,6 +161,9 @@ class DownsampleConv(nn.Module):
   kernel_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
       scale=1.0, mode="fan_avg", distribution="uniform"
   )
+  precision: PrecisionLike = None
+  dtype: jnp.dtype = jnp.float32
+  param_dtype: jnp.dtype = jnp.float32

   @nn.compact
   def __call__(self, inputs: Array) -> Array:
@@ -152,4 +190,7 @@ def __call__(self, inputs: Array) -> Array:
         strides=self.ratios,
         kernel_init=self.kernel_init,
         padding="VALID",
+        precision=self.precision,
+        dtype=self.dtype,
+        param_dtype=self.param_dtype,
     )(inputs)

diff --git a/swirl_dynamics/lib/layers/residual.py b/swirl_dynamics/lib/layers/residual.py
index fa16d38..afcd861 100644
--- a/swirl_dynamics/lib/layers/residual.py
+++ b/swirl_dynamics/lib/layers/residual.py
@@ -13,12 +13,18 @@
 # limitations under the License.

"""Residual layer modules.""" - import flax.linen as nn import jax import jax.numpy as jnp Array = jax.Array +PrecisionLike = ( + None + | str + | jax.lax.Precision + | tuple[str, str] + | tuple[jax.lax.Precision, jax.lax.Precision] +) class CombineResidualWithSkip(nn.Module): @@ -31,6 +37,9 @@ class CombineResidualWithSkip(nn.Module): """ project_skip: bool = False + precision: PrecisionLike = None + dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, *, residual: Array, skip: Array) -> Array: @@ -40,5 +49,8 @@ def __call__(self, *, residual: Array, skip: Array) -> Array: kernel_init=nn.initializers.variance_scaling( scale=1.0, mode="fan_avg", distribution="uniform" ), + precision=self.precision, + dtype=self.dtype, + param_dtype=self.param_dtype, )(skip) return (skip + residual) / jnp.sqrt(2.0) diff --git a/swirl_dynamics/lib/layers/resize.py b/swirl_dynamics/lib/layers/resize.py index 770d878..ffd997e 100644 --- a/swirl_dynamics/lib/layers/resize.py +++ b/swirl_dynamics/lib/layers/resize.py @@ -22,6 +22,13 @@ from swirl_dynamics.lib.layers import convolutions Array = jax.Array +PrecisionLike = ( + None + | str + | jax.lax.Precision + | tuple[str, str] + | tuple[jax.lax.Precision, jax.lax.Precision] +) class FilteredResize(nn.Module): @@ -35,7 +42,9 @@ class FilteredResize(nn.Module): 'LATLON', 'LONLAT]. initializer: The initializer for the convolution kernels. use_local: Whether to use unshared weights in the filtering. - dtype: The data type of the input and weights. + precision: Level of precision used in the convolutional layer. + dtype: The data type of the input and output. + params_dtype: The data type of of the weights. """ output_size: Sequence[int] @@ -46,7 +55,9 @@ class FilteredResize(nn.Module): stddev=0.02 ) use_local: bool = False + precision: PrecisionLike = None dtype: jnp.dtype = jnp.float32 + param_dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, inputs: Array) -> Array: @@ -87,5 +98,8 @@ def __call__(self, inputs: Array) -> Array: padding=self.padding, kernel_init=self.initializer, use_local=self.use_local, + dtype=self.dtype, + precision=self.precision, + param_dtype=self.param_dtype, )(resized) return out