From 1f77160c47c520c8aeb720bd65ecb6b0dd37f339 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Leonardo=20Zepeda-N=C3=BA=C3=B1ez?=
Date: Mon, 18 Mar 2024 14:19:24 -0700
Subject: [PATCH] Code update

PiperOrigin-RevId: 616945367
---
 swirl_dynamics/lib/diffusion/vivit.py       |  34 +++--
 .../lib/diffusion/vivit_diffusion.py        | 138 +++++++++++++++++-
 .../lib/diffusion/vivit_diffusion_test.py   |  49 ++++++-
 .../config_kolmogorov_med_res_3d_axial.py   |  93 ++++++++++++
 4 files changed, 296 insertions(+), 18 deletions(-)
 create mode 100644 swirl_dynamics/projects/spatiotemporal_modeling/configs/config_kolmogorov_med_res_3d_axial.py

diff --git a/swirl_dynamics/lib/diffusion/vivit.py b/swirl_dynamics/lib/diffusion/vivit.py
index 50a12f4..94f24f2 100644
--- a/swirl_dynamics/lib/diffusion/vivit.py
+++ b/swirl_dynamics/lib/diffusion/vivit.py
@@ -311,7 +311,7 @@ class TemporalEncoder(nn.Module):
 
   Attributes:
     temporal_encoding_config: Dictionary containing some of the options.
-    patches: The size of each Patch for the temporal embedding.
+    patches: The size of each patch for the temporal embedding.
     features_out: Number of features in the output.
     encoded_shapes: Shape in time, height, and width of the encoded inputs.
   """
@@ -468,7 +468,7 @@ def _run_attention_on_axis(inputs, axis, two_d_shape):
 
 
 class Encoder3DFactorizedSelfAttentionBlock(nn.Module):
-  """Encoder with factorized self attention block.
+  """Encoder with 3D factorized self attention block.
 
   Attributes:
     mlp_dim: Dimension of the mlp on top of attention block.
@@ -478,9 +478,10 @@ class Encoder3DFactorizedSelfAttentionBlock(nn.Module):
     dropout_rate: Dropout rate.
     attention_dropout_rate: Dropout for attention heads.
     droplayer_p: Probability of dropping a layer.
-    attention_order: The order to do the attention. Choice of {time_space,
-      space_time}.
-    dtype: the dtype of the computation (default: float32).
+    attention_order: The order in which the axial attention is performed. You
+      can choose either `time_space` (time_height_width) or `space_time`
+      (height_width_time).
+    dtype: The data type of the computation.
   """
   mlp_dim: int
   num_heads: int
@@ -525,7 +526,16 @@ def __call__(self, inputs: Array, *, deterministic: bool) -> Array:
     def _run_attention_on_axis(
         inputs: Array, axis: int, three_d_shape: tuple[int, ...]
     ) -> Array:
-      """Reshapes the input and run attention on the given axis."""
+      """Reshapes the input and runs attention on the given axis.
+
+      Args:
+        inputs: Input tensor in 3d space (i.e., a rank-5 tensor).
+        axis: Index of the axis along which the axial attention is performed.
+        three_d_shape: Original three-dimensional shape.
+
+      Returns:
+        A tensor with the same spatial dimensions as the input.
+      """
       inputs = reshape_utils.reshape_3d_to_1d_factorized(inputs, axis=axis)
       x = nn.LayerNorm(
          dtype=self.dtype, name='LayerNorm_{}'.format(_AXIS_TO_NAME_3D[axis])
      )(inputs)
@@ -688,7 +698,9 @@ def __call__(self, inputs: Array, *, train: bool) -> Array:
       # TODO: change this one to handle non-square domains.
       height = width = int(np.sqrt(num_tokens // self.temporal_dims))
       if height * width * self.temporal_dims != num_tokens:
-        raise ValueError('Input is assumed to be square for sinusoidal init.')
+        raise ValueError('Input is assumed to be square in the '
+                         'spatial dimensions for sinusoidal init. Instead the '
+                         f'dimensions are {height} and {width}.')
 
       inputs_reshape = inputs.reshape([batch, self.temporal_dims, height,
                                        width, hidden_dim])
@@ -766,7 +778,7 @@ class ViViT(nn.Module):
     dropout_rate: Dropout rate.
     attention_dropout_rate: Dropout for attention heads.
     stochastic_droplayer_rate: Probability of dropping a layer. Linearly
-      increases from 0 to the provided value..
+      increases from 0 to the provided value.
     dtype: JAX data type for the weights and activation functions.
   """
 
@@ -795,10 +807,10 @@ def __call__(
         hidden_size=self.hidden_size,
     )(x, train=is_training)
 
-    bathc_size, enc_t, enc_h, enc_w, emb_dim = x.shape
+    batch_size, enc_t, enc_h, enc_w, emb_dim = x.shape
     num_tokens = enc_t * enc_h * enc_w
 
-    x = jnp.reshape(x, (bathc_size, num_tokens, emb_dim))
+    x = jnp.reshape(x, (batch_size, num_tokens, emb_dim))
 
     x = TransformerBlock(
         temporal_dims=temporal_dims,
@@ -811,6 +823,8 @@ def __call__(
         stochastic_droplayer_rate=self.stochastic_droplayer_rate,
         dtype=self.dtype,
         name='Transformer',
+        # TODO: clean this input/remove the temporal dims.
+        encoded_shape=(batch_size, enc_t, enc_h, enc_w, emb_dim)
     )(x, train=is_training)
 
     x = TemporalDecoder(
diff --git a/swirl_dynamics/lib/diffusion/vivit_diffusion.py b/swirl_dynamics/lib/diffusion/vivit_diffusion.py
index eee0e2a..5c68e76 100644
--- a/swirl_dynamics/lib/diffusion/vivit_diffusion.py
+++ b/swirl_dynamics/lib/diffusion/vivit_diffusion.py
@@ -54,6 +54,12 @@
     2: 'space',
 })
 
+_AXIS_TO_NAME_3D = dict({
+    1: 'time',
+    2: 'height',
+    3: 'width',
+})
+
 
 class EncoderEmbeddingBlock(nn.Module):
   """Transformer encoder block.
@@ -251,6 +257,125 @@ def _run_attention_on_axis(
 
     return x + y
 
+
+class Factorized3DSelfAttentionEmbeddingBlock(nn.Module):
+  """Encoder with 3D factorized self attention block.
+
+  Attributes:
+    mlp_dim: Dimension of the mlp on top of attention block.
+    num_heads: Number of heads.
+    three_dim_shape: Encoded shape (batch, time, height, width, emb_dim).
+    attention_kernel_initializer: Initializer to use for attention layers.
+    dropout_rate: Dropout rate.
+    attention_dropout_rate: Dropout for attention heads.
+    droplayer_p: Probability of dropping a layer.
+    attention_order: The order in which the axial attention is performed. The
+      two choices are 'time_height_width' and 'height_width_time'.
+    dtype: The data type of the computation (default: float32).
+  """
+  mlp_dim: int
+  num_heads: int
+  three_dim_shape: tuple[int, int, int, int, int]
+  attention_kernel_initializer: Initializer
+  dropout_rate: float = 0.1
+  attention_dropout_rate: float = 0.1
+  droplayer_p: Optional[float] = None
+  attention_order: str = 'time_height_width'
+  dtype: jnp.dtype = jnp.float32
+
+  @nn.compact
+  def __call__(self, inputs: Array, emb: Array, *, deterministic: bool):
+    """Applies the Factorized3DSelfAttentionEmbeddingBlock module."""
+
+    batch_size, num_tokens, emb_dim = inputs.shape
+    _, enc_t, enc_h, enc_w, _ = self.three_dim_shape
+
+    if num_tokens != (enc_t * enc_h * enc_w):
+      raise ValueError('The product of the encoded dimensions for time, '
+                       f'height and width ({enc_t}, {enc_h}, {enc_w}), '
+                       'respectively, should match the number of tokens '
+                       f'({num_tokens}) in the input.')
+
+    inputs = jnp.reshape(inputs, self.three_dim_shape)
+
+    self_attention = functools.partial(
+        nn.SelfAttention,
+        num_heads=self.num_heads,
+        kernel_init=self.attention_kernel_initializer,
+        broadcast_dropout=False,
+        dropout_rate=self.attention_dropout_rate,
+        dtype=self.dtype)
+
+    # Order of the Axial Transformer.
+    if self.attention_order == 'time_height_width':
+      attention_axes = (1, 2, 3)
+    elif self.attention_order == 'height_width_time':
+      attention_axes = (2, 3, 1)
+    else:
+      raise ValueError(f'Invalid attention order {self.attention_order}.')
+
+    def _run_attention_on_axis(
+        inputs: Array,
+        emb: Array,
+        axis: int,
+        three_dim_shape: tuple[int, int, int, int, int],
+    ):
+      """Reshapes the input and runs attention on the given axis."""
+      # shape: (batch, num_tokens, emb_dim)
+      inputs = reshape_utils.reshape_3d_to_1d_factorized(inputs, axis=axis)
+      x = nn.LayerNorm(
+          dtype=self.dtype, name='LayerNorm_{}'.format(_AXIS_TO_NAME_3D[axis])
+      )(inputs)
+      # Add first FiLM layer; some reshaping is necessary (see Fig. 3 in [1]).
+      in_shape = x.shape
+      x = unets.AdaptiveScale()(
+          x.reshape(three_dim_shape[0], -1, three_dim_shape[-1]), emb
+      ).reshape(in_shape)
+
+      x = self_attention(
+          name='MultiHeadDotProductAttention_{}'.format(_AXIS_TO_NAME_3D[axis])
+      )(x, deterministic=deterministic)
+      x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
+
+      # Second FiLM layer (see Fig. 3 in [1]).
+      in_shape = x.shape
+      x = unets.AdaptiveScale()(
+          x.reshape(three_dim_shape[0], -1, three_dim_shape[-1]), emb
+      ).reshape(in_shape)
+
+      x = x + inputs
+
+      # shape: (batch, num_frames, hw, emb_dim)
+      return reshape_utils.reshape_to_3d_factorized(
+          x, axis=axis, three_d_shape=three_dim_shape
+      )
+
+    x = inputs
+    three_dim_shape = inputs.shape
+
+    # shape: (batch, num_frames, hw, emb_dim)
+    for axis in attention_axes:
+      x = _run_attention_on_axis(x, emb, axis, three_dim_shape)
+
+    # MLP block.
+    x = jnp.reshape(x, (batch_size, num_tokens, emb_dim))
+    y = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_mlp')(x)
+    # Add FiLM layer before the MLP block (see Fig. 3 in [1]).
+    y = unets.AdaptiveScale()(y, emb)
+    y = vivit.MlpBlock(
+        mlp_dim=self.mlp_dim,
+        dtype=self.dtype,
+        dropout_rate=self.dropout_rate,
+        activation_fn=nn.gelu,
+        kernel_init=nn.initializers.xavier_uniform(),
+        bias_init=nn.initializers.normal(stddev=1e-6),
+        name='MlpBlock')(
+            y, deterministic=deterministic)
+    # Add FiLM layer after the MLP block (see Fig. 3 in [1]).
+    y = unets.AdaptiveScale()(y, emb)
+
+    return x + y
+
+
 class TransformerEmbeddingBlock(nn.Module):
   """Transformer Block with embeddings.
@@ -269,6 +394,8 @@ class TransformerEmbeddingBlock(nn.Module):
       uses independent dropping patterns for each skip-connection.
     positional_embedding: The type of positional embedding to use. Supported
       values are {learned_1d, sinusoidal_1d, sinusoidal_3d, none}.
+    encoded_shape: Three-dimensional shape of the corresponding encoded
+      tensor. This is required for the three-dimensional axial attention.
     normalise_output: If True, perform layernorm on the output.
   """
 
@@ -282,6 +409,7 @@ class TransformerEmbeddingBlock(nn.Module):
   stochastic_droplayer_rate: float = 0.0
   dtype: jnp.dtype = jnp.float32
   positional_embedding: str = 'sinusoidal_3d'
+  encoded_shape: Optional[tuple[int, ...]] = None
   normalise_output: bool = True
 
   @nn.compact
@@ -324,7 +452,14 @@ def __call__(self, inputs: Array, emb: Array, *, train: bool) -> Array:
              self.attention_config.get('attention_kernel_init_method',
                                        'xavier')],  # pytype: disable=attribute-error
          temporal_dims=self.temporal_dims)
-      # TODO: implement factorized_dot_product_attention.
+    elif self.attention_config.type == 'factorized_3d_self_attention_block':  # pytype: disable=attribute-error
+      encoder_block = functools.partial(
+          Factorized3DSelfAttentionEmbeddingBlock,
+          attention_order=self.attention_config.attention_order,  # pytype: disable=attribute-error
+          attention_kernel_initializer=_KERNEL_INITIALIZERS[
+              self.attention_config.get('attention_kernel_init_method',  # pytype: disable=attribute-error
+                                        'xavier')],
+          three_dim_shape=self.encoded_shape)
     else:
       raise ValueError(f'Unknown attention type {self.attention_config.type}')  # pytype: disable=attribute-error
 
@@ -447,6 +582,7 @@ def __call__(
         stochastic_droplayer_rate=self.stochastic_droplayer_rate,
         positional_embedding=self.positional_embedding,
         dtype=self.dtype,
+        encoded_shape=(batch_size, enc_t, enc_h, enc_w, emb_dim),
         name='Transformer')(
             x, emb, train=is_training)
 
diff --git a/swirl_dynamics/lib/diffusion/vivit_diffusion_test.py b/swirl_dynamics/lib/diffusion/vivit_diffusion_test.py
index 4d4012a..e4fe79d 100644
--- a/swirl_dynamics/lib/diffusion/vivit_diffusion_test.py
+++ b/swirl_dynamics/lib/diffusion/vivit_diffusion_test.py
@@ -24,14 +24,49 @@ class VivitDiffusionTest(parameterized.TestCase):
 
   @parameterized.parameters(
-      ((16, 32, 32), (2, 2, 2), 1),
-      ((16, 64, 64), (4, 4, 4), 1,),
-      ((32, 64, 64), (8, 8, 8), 2,),
+      (
+          (16, 32, 32),
+          (2, 2, 2),
+          1,
+          3,
+          'factorized_self_attention_block',
+          'time_space',
+      ),
+      (
+          (16, 64, 64),
+          (4, 4, 4),
+          1,
+          3,
+          'factorized_self_attention_block',
+          'space_time',
+      ),
+      (
+          (32, 64, 64),
+          (8, 8, 8),
+          2,
+          6,
+          'factorized_3d_self_attention_block',
+          'time_height_width',
+      ),
+      (
+          (32, 32, 32),
+          (4, 4, 4),
+          2,
+          6,
+          'factorized_3d_self_attention_block',
+          'height_width_time',
+      ),
   )
   def test_vivit_diffusion_output_shape(
-      self, spatial_dims, patch_size, output_features
+      self,
+      spatial_dims,
+      patch_size,
+      output_features,
+      channels,
+      attention_type,
+      attention_order,
   ):
-    batch, channels = 2, 3
+    batch = 2
     x = np.random.randn(batch, *spatial_dims, channels)
     sigma = np.linspace(0, 1, batch)
 
@@ -44,8 +79,8 @@ def test_vivit_diffusion_output_shape(
     temporal_encoding_config.kernel_init_method = 'central_frame_initializer'
 
     attention_config = ml_collections.ConfigDict()
-    attention_config.type = 'factorized_self_attention_block'
-    attention_config.attention_order = 'time_space'
+    attention_config.type = attention_type
+    attention_config.attention_order = attention_order
     attention_config.attention_kernel_init_method = 'xavier'
 
     vivit_model = vivit_diffusion.ViViTDiffusion(
diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/config_kolmogorov_med_res_3d_axial.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/config_kolmogorov_med_res_3d_axial.py
new file mode 100644
index 0000000..f1bac38
--- /dev/null
+++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/config_kolmogorov_med_res_3d_axial.py
@@ -0,0 +1,93 @@
+# Copyright 2024 The swirl_dynamics Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +r"""Config file for ViViT Denoiser. + + +""" + +import ml_collections +# pylint: disable=line-too-long +DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_256_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5' +# pylint: enable=line-too-long + + +def get_config(): + """Returns the base experiment configuration.""" + config = ml_collections.ConfigDict() + + # Model. + # TODO undo all the nested dictionaries. + config.model_name = 'ViViT Denoiser' + config.model = ml_collections.ConfigDict() + config.model.hidden_size = 576 + config.spatial_downsample_factor = 1 + + config.model.num_heads = 18 + config.model.mlp_dim = 512 + config.model.num_layers = 6 + config.model.dropout_rate = 0.3 + config.model_dtype_str = 'float32' + config.model.noise_embed_dim = 256 + config.model.diffusion_scheme = 'variance_exploding' + + config.save_interval_steps = 1000 + config.max_checkpoints_to_keep = 10 + + # TODO: create custom data structures. + config.model.temporal_encoding_config = ml_collections.ConfigDict() + config.model.temporal_encoding_config.method = '3d_conv' + # pylint: disable=line-too-long + config.model.temporal_encoding_config.kernel_init_method = 'central_frame_initializer' + # pylint: enable=line-too-long + config.model.positional_embedding = 'sinusoidal_3d' # 'sinusoidal_3d' + + # TODO: patches doesn't need to be a dictionary. + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (4, 4, 4) # (time, height, width) + + config.model.attention_config = ml_collections.ConfigDict() + # config.model.attention_config.type = 'factorized_encoder' + config.model.attention_config.type = 'factorized_3d_self_attention_block' + config.model.attention_config.attention_order = 'time_height_width' + config.model.attention_config.attention_kernel_init_method = 'xavier' + + config.data = ml_collections.ConfigDict() + config.data.file_path_data = DATA_PATH + config.data.num_time_steps = 32 + config.data.time_stride = 2 + config.data.batch_size = 8 + config.data.normalize = True + config.data.random_seed = 1 + config.data.tf_lookup_batch_size = 32 + config.data.std = 1.0 + + config.optimizer = ml_collections.ConfigDict() + config.optimizer.num_train_steps = 1000000 + config.optimizer.initial_lr = 0.0 + config.optimizer.peak_lr = 3e-4 + config.optimizer.warmup_steps = 50000 + config.optimizer.end_lr = 1e-6 + config.optimizer.ema_decay = 0.999 + config.optimizer.ckpt_interval = 1000 + config.optimizer.max_ckpt_to_keep = 5 + config.optimizer.clip_min = 1e-4 + config.optimizer.metric_aggreration_steps = 50 + config.optimizer.eval_every_steps = 1000 + config.optimizer.num_batches_per_eval = 8 + config.optimizer.clip = 1. + config.optimizer.beta1 = 0.99 + + return config +