Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588279624
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Dec 6, 2023
1 parent 85190d5 commit de2197c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
38 changes: 29 additions & 9 deletions swirl_dynamics/lib/diffusion/vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@
3: 'width',
})

# Permutation to perform a depth to space in three dimension.
# Basically we seek to go from:
# (batch_size, time, height, width,
# time_patch, height_patch, width_patch, emb_dim )
# to
# (batch_size, time, time_patch, height, height_patch,
# width, width_patch, emb_dim)
_PERMUTATION = (0, 1, 4, 2, 5, 3, 6, 7)


def get_fixed_sincos_position_embedding(
x_shape: Shape,
Expand Down Expand Up @@ -338,12 +347,12 @@ def __call__(self, inputs: Array, *, train: bool) -> tuple[Array, int]:


class TemporalDecoder(nn.Module):
"""Temporal Decoder.
"""Temporal Decoder from latent space to original 3d space.
Attributes:
patches: The size of each patch used for the temporal embedding.
features_out: Number of features in the output.
encoded_shapes: Shape in time, height, and width of the encoded inputs.
encoded_shapes: Shape of the encoded input following [time, height, width].
"""
patches: tuple[int, ...]
features_out: int
Expand All @@ -355,15 +364,26 @@ def __call__(self, inputs: Array, *, train: bool) -> Array:
# We suppose that the input is batch_size, num_tokens, emb_dim
batch_size, _, emb_dim = inputs.shape

time, height, width = self.patches
t, h, w = self.patches
enc_t, enc_h, enc_w = self.encoded_shapes
x = jnp.reshape(inputs, (batch_size, *self.encoded_shapes, emb_dim))

x = nn.ConvTranspose(features=self.features_out,
kernel_size=(2, 2, 2),
strides=(time, height, width),
kernel_dilation=(time, height, width),
transpose_kernel=True,
name='conv_transpose_temporal_decoder')(x)
x = nn.Conv(
features=self.features_out * t * h * w,
kernel_size=(1, 1, 1),
strides=(1, 1, 1),
name='conv_transpose_temporal_decoder',
)(x)

# TODO(lzepedanunez): Use unets.depth_to_space here instead.
x = jnp.reshape(
x, (batch_size, *self.encoded_shapes, t, h, w, self.features_out)
)
x = jnp.transpose(x, _PERMUTATION)
x = jnp.reshape(
x, (batch_size, enc_t * t, enc_h * h, enc_w * w, self.features_out)
)

return x


Expand Down
9 changes: 5 additions & 4 deletions swirl_dynamics/lib/diffusion/vivit_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,7 @@ class ViViTDiffusion(nn.Module):
attention_dropout_rate: Dropout for attention heads.
stochastic_droplayer_rate: Probability of dropping a layer. Linearly
increases from 0 to the provided value..
return_prelogits: If true, return the final representation of the network
before the classification head. Useful when using features for a
downstream task.
positional_embedding: Type of positional encoding.
dtype: JAX data type for activations.
"""

Expand All @@ -387,8 +385,10 @@ class ViViTDiffusion(nn.Module):
temporal_encoding_config: ml_collections.ConfigDict
attention_config: ml_collections.ConfigDict
dropout_rate: float = 0.1
noise_embed_dim: int = 256
attention_dropout_rate: float = 0.1
stochastic_droplayer_rate: float = 0.0
positional_embedding: str = 'sinusoidal_3d'
dtype: jnp.dtype = jnp.float32

@nn.compact
Expand All @@ -403,7 +403,7 @@ def __call__(
batch_size_input, num_frames, height, width, _ = x.shape

# Computing the embedding for modulation.
emb = unets.FourierEmbedding()(sigma)
emb = unets.FourierEmbedding(dims=self.noise_embed_dim)(sigma)

# Shape: (batch_size, num_frames//patch_time, height//patch_height,
# width//patch_width, emd_dim).
Expand Down Expand Up @@ -431,6 +431,7 @@ def __call__(
dropout_rate=self.dropout_rate,
attention_dropout_rate=self.attention_dropout_rate,
stochastic_droplayer_rate=self.stochastic_droplayer_rate,
positional_embedding=self.positional_embedding,
dtype=self.dtype,
name='Transformer')(
x, emb, train=is_training)
Expand Down
26 changes: 26 additions & 0 deletions swirl_dynamics/lib/diffusion/vivit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,32 @@ def test_3dfactorized_self_attention_output_shape(

self.assertEqual(out.shape, (batch_size, num_tokens, channel))

@parameterized.parameters(
((1, 2, 4, 8, 3), (2, 2, 2), 1),
((1, 2, 8, 8, 3), (2, 2, 2), 2),
((1, 2, 8, 8, 3), (2, 2, 2), 3),
)
def test_decoder_output_shape(self, enc_shapes, patch_size, output_features):
batch_size, time, height, width, channels = enc_shapes

t, h, w = patch_size
x = jax.random.normal(jax.random.PRNGKey(0), enc_shapes)

decoder = vivit.TemporalDecoder(
patches=patch_size,
features_out=output_features,
encoded_shapes=(time, height, width),
)

out, _ = decoder.init_with_output(
jax.random.PRNGKey(42),
x.reshape((batch_size, -1, channels)),
train=False,
)

self.assertEqual(out.shape, (batch_size, time * t, height * h, width * w,
output_features))


if __name__ == "__main__":
absltest.main()

0 comments on commit de2197c

Please sign in to comment.