diff --git a/swirl_dynamics/lib/diffusion/vivit.py b/swirl_dynamics/lib/diffusion/vivit.py
index 3258d7c..70f8c54 100644
--- a/swirl_dynamics/lib/diffusion/vivit.py
+++ b/swirl_dynamics/lib/diffusion/vivit.py
@@ -51,6 +51,15 @@
     3: 'width',
 })
 
+# Permutation to perform a depth-to-space operation in three dimensions.
+# We seek to go from:
+#   (batch_size, time, height, width,
+#    time_patch, height_patch, width_patch, emb_dim)
+# to
+#   (batch_size, time, time_patch, height, height_patch,
+#    width, width_patch, emb_dim)
+_PERMUTATION = (0, 1, 4, 2, 5, 3, 6, 7)
+
 
 def get_fixed_sincos_position_embedding(
     x_shape: Shape,
@@ -338,12 +347,12 @@ def __call__(self, inputs: Array, *, train: bool) -> tuple[Array, int]:
 
 
 class TemporalDecoder(nn.Module):
-  """Temporal Decoder.
+  """Temporal decoder from the latent space to the original 3D space.
 
   Attributes:
     patches: The size of each patch used for the temporal embedding.
     features_out: Number of features in the output.
-    encoded_shapes: Shape in time, height, and width of the encoded inputs.
+    encoded_shapes: Shape of the encoded input, following [time, height, width].
   """
   patches: tuple[int, ...]
   features_out: int
@@ -355,15 +364,26 @@ def __call__(self, inputs: Array, *, train: bool) -> Array:
 
     # We suppose that the input is batch_size, num_tokens, emb_dim
     batch_size, _, emb_dim = inputs.shape
 
-    time, height, width = self.patches
+    t, h, w = self.patches
+    enc_t, enc_h, enc_w = self.encoded_shapes
 
     x = jnp.reshape(inputs, (batch_size, *self.encoded_shapes, emb_dim))
-    x = nn.ConvTranspose(features=self.features_out,
-                         kernel_size=(2, 2, 2),
-                         strides=(time, height, width),
-                         kernel_dilation=(time, height, width),
-                         transpose_kernel=True,
-                         name='conv_transpose_temporal_decoder')(x)
+    x = nn.Conv(
+        features=self.features_out * t * h * w,
+        kernel_size=(1, 1, 1),
+        strides=(1, 1, 1),
+        name='conv_transpose_temporal_decoder',
+    )(x)
+
+    # TODO(lzepedanunez): Use unets.depth_to_space here instead.
+    x = jnp.reshape(
+        x, (batch_size, *self.encoded_shapes, t, h, w, self.features_out)
+    )
+    x = jnp.transpose(x, _PERMUTATION)
+    x = jnp.reshape(
+        x, (batch_size, enc_t * t, enc_h * h, enc_w * w, self.features_out)
+    )
+
     return x
diff --git a/swirl_dynamics/lib/diffusion/vivit_diffusion.py b/swirl_dynamics/lib/diffusion/vivit_diffusion.py
index c912e02..10b2f32 100644
--- a/swirl_dynamics/lib/diffusion/vivit_diffusion.py
+++ b/swirl_dynamics/lib/diffusion/vivit_diffusion.py
@@ -372,9 +372,7 @@ class ViViTDiffusion(nn.Module):
     attention_dropout_rate: Dropout for attention heads.
     stochastic_droplayer_rate: Probability of dropping a layer. Linearly
       increases from 0 to the provided value.
-    return_prelogits: If true, return the final representation of the network
-      before the classification head. Useful when using features for a
-      downstream task.
+    positional_embedding: Type of positional encoding.
     dtype: JAX data type for activations.
   """
 
@@ -387,8 +385,10 @@ class ViViTDiffusion(nn.Module):
   temporal_encoding_config: ml_collections.ConfigDict
   attention_config: ml_collections.ConfigDict
   dropout_rate: float = 0.1
+  noise_embed_dim: int = 256
   attention_dropout_rate: float = 0.1
   stochastic_droplayer_rate: float = 0.0
+  positional_embedding: str = 'sinusoidal_3d'
   dtype: jnp.dtype = jnp.float32
 
   @nn.compact
@@ -403,7 +403,7 @@ def __call__(
     batch_size_input, num_frames, height, width, _ = x.shape
 
     # Computing the embedding for modulation.
-    emb = unets.FourierEmbedding()(sigma)
+    emb = unets.FourierEmbedding(dims=self.noise_embed_dim)(sigma)
 
     # Shape: (batch_size, num_frames//patch_time, height//patch_height,
     #         width//patch_width, emb_dim).
@@ -431,6 +431,7 @@ def __call__(
           dropout_rate=self.dropout_rate,
           attention_dropout_rate=self.attention_dropout_rate,
           stochastic_droplayer_rate=self.stochastic_droplayer_rate,
+          positional_embedding=self.positional_embedding,
           dtype=self.dtype,
           name='Transformer')(
               x, emb, train=is_training)
diff --git a/swirl_dynamics/lib/diffusion/vivit_test.py b/swirl_dynamics/lib/diffusion/vivit_test.py
index cc4d60f..63df007 100644
--- a/swirl_dynamics/lib/diffusion/vivit_test.py
+++ b/swirl_dynamics/lib/diffusion/vivit_test.py
@@ -92,6 +92,32 @@ def test_3dfactorized_self_attention_output_shape(
 
     self.assertEqual(out.shape, (batch_size, num_tokens, channel))
 
+  @parameterized.parameters(
+      ((1, 2, 4, 8, 3), (2, 2, 2), 1),
+      ((1, 2, 8, 8, 3), (2, 2, 2), 2),
+      ((1, 2, 8, 8, 3), (2, 2, 2), 3),
+  )
+  def test_decoder_output_shape(self, enc_shapes, patch_size, output_features):
+    batch_size, time, height, width, channels = enc_shapes
+
+    t, h, w = patch_size
+    x = jax.random.normal(jax.random.PRNGKey(0), enc_shapes)
+
+    decoder = vivit.TemporalDecoder(
+        patches=patch_size,
+        features_out=output_features,
+        encoded_shapes=(time, height, width),
+    )
+
+    out, _ = decoder.init_with_output(
+        jax.random.PRNGKey(42),
+        x.reshape((batch_size, -1, channels)),
+        train=False,
+    )
+
+    self.assertEqual(out.shape, (batch_size, time * t, height * h, width * w,
+                                 output_features))
+
 
 if __name__ == "__main__":
   absltest.main()
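
Note for reviewers: the reshape -> transpose -> reshape sequence added to TemporalDecoder implements a 3D depth-to-space (pixel-shuffle) rearrangement. Below is a minimal standalone sketch of that step in plain JAX, assuming the input has already been projected to features_out * t * h * w channels by the 1x1x1 convolution; depth_to_space_3d and PERM are illustrative names for this sketch, not part of swirl_dynamics.

import jax
import jax.numpy as jnp

# Same permutation as _PERMUTATION in the diff above.
PERM = (0, 1, 4, 2, 5, 3, 6, 7)

def depth_to_space_3d(x, patch, features_out):
  """Rearranges (b, t, h, w, t_p*h_p*w_p*f) into (b, t*t_p, h*h_p, w*w_p, f)."""
  b, t, h, w, _ = x.shape
  t_p, h_p, w_p = patch
  # Split the channel axis into the three patch axes plus the output features.
  x = jnp.reshape(x, (b, t, h, w, t_p, h_p, w_p, features_out))
  # Interleave each patch axis next to its corresponding spatio-temporal axis.
  x = jnp.transpose(x, PERM)
  # Merge each (axis, patch_axis) pair into a single upsampled axis.
  return jnp.reshape(x, (b, t * t_p, h * h_p, w * w_p, features_out))

# Example: encoded grid (2, 4, 8), patch (2, 2, 2), 3 output features.
x = jax.random.normal(jax.random.PRNGKey(0), (1, 2, 4, 8, 2 * 2 * 2 * 3))
out = depth_to_space_3d(x, (2, 2, 2), 3)
assert out.shape == (1, 4, 8, 16, 3)

This is the same shape relation that test_decoder_output_shape checks end to end; per the TODO in the diff, the inline version is expected to be replaced by unets.depth_to_space eventually.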