
Commit

Code update
PiperOrigin-RevId: 616945367
Forgotten authored and The swirl_dynamics Authors committed Mar 18, 2024
1 parent b4757fd commit 1f77160
Showing 4 changed files with 296 additions and 18 deletions.
34 changes: 24 additions & 10 deletions swirl_dynamics/lib/diffusion/vivit.py
@@ -311,7 +311,7 @@ class TemporalEncoder(nn.Module):
Attributes:
temporal_encoding_config: Dictionary containing some of the options.
patches: The size of each Patch for the temporal embedding.
patches: The size of each patch for the temporal embedding.
features_out: Number of features in the output.
encoded_shapes: Shape in time, height, and width of the encoded inputs.
"""
@@ -468,7 +468,7 @@ def _run_attention_on_axis(inputs, axis, two_d_shape):


class Encoder3DFactorizedSelfAttentionBlock(nn.Module):
"""Encoder with factorized self attention block.
"""Encoder with 3D factorized self attention block.
Attributes:
mlp_dim: Dimension of the mlp on top of attention block.
@@ -478,9 +478,10 @@ class Encoder3DFactorizedSelfAttentionBlock(nn.Module):
dropout_rate: Dropout rate.
attention_dropout_rate: Dropout for attention heads.
droplayer_p: Probability of dropping a layer.
attention_order: The order to do the attention. Choice of {time_space,
space_time}.
dtype: the dtype of the computation (default: float32).
attention_order: The order in which the axial attention is performed. You
can choose either `time_space` (time_height_width) or `space_time`
(height_width_time).
dtype: The data type of the computation.
"""
mlp_dim: int
num_heads: int
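
For readers unfamiliar with the factorized (axial) attention pattern referenced above, here is a minimal, self-contained sketch of attending along a single axis of a (batch, time, height, width, emb) tensor. It assumes only standard flax.linen and jax.numpy APIs and is not the swirl_dynamics implementation, which goes through its reshape_utils helpers and per-axis layer names:

import flax.linen as nn
import jax.numpy as jnp


class AxialSelfAttentionSketch(nn.Module):
  """Attends along a single axis of a (batch, time, height, width, emb) tensor."""
  num_heads: int
  axis: int  # 1 = time, 2 = height, 3 = width

  @nn.compact
  def __call__(self, x, *, deterministic: bool = True):
    # Move the attended axis next to the features and fold every other axis
    # into the batch dimension, so attention only mixes tokens along `axis`.
    moved = jnp.moveaxis(x, self.axis, -2)            # (..., axis_len, emb)
    folded = moved.reshape((-1,) + moved.shape[-2:])  # (batch', axis_len, emb)
    y = nn.LayerNorm()(folded)
    # num_heads must divide the embedding dimension.
    y = nn.SelfAttention(num_heads=self.num_heads)(y, deterministic=deterministic)
    y = folded + y                                    # residual connection
    return jnp.moveaxis(y.reshape(moved.shape), -2, self.axis)

In the 3D factorized block the time, height, and width axes are each handled this way, in the order given by attention_order, while the existing 2D variant attends over time and over the merged spatial axis.
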
@@ -525,7 +526,16 @@ def __call__(self, inputs: Array, *, deterministic: bool) -> Array:
def _run_attention_on_axis(
inputs: Array, axis: int, three_d_shape: tuple[int, ...]
) -> Array:
"""Reshapes the input and run attention on the given axis."""
"""Reshapes the input and run attention on the given axis.
Args:
inputs: Input tensor in 3d space (i.e. a 5-tensor).
axis: Index of the axis in which perform the axial attention.
three_d_shape: Original three dimensional shape.
Returns:
An tensor with the same spatial dimensions as the input.
"""
inputs = reshape_utils.reshape_3d_to_1d_factorized(inputs, axis=axis)
x = nn.LayerNorm(
dtype=self.dtype, name='LayerNorm_{}'.format(_AXIS_TO_NAME_3D[axis])
@@ -688,7 +698,9 @@ def __call__(self, inputs: Array, *, train: bool) -> Array:
# TODO: change this one to handle non-square domains.
height = width = int(np.sqrt(num_tokens // self.temporal_dims))
if height * width * self.temporal_dims != num_tokens:
raise ValueError('Input is assumed to be square for sinusoidal init.')
raise ValueError('Input is assumed to be square in the '
'spatial dimensions for sinusoidal init. Instead the '
f'dimensions are {height} and {width}.')

inputs_reshape = inputs.reshape([batch, self.temporal_dims, height, width,
hidden_dim])
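
As a quick numeric illustration of the squareness check above (hypothetical sizes, not taken from the library):

import numpy as np

temporal_dims = 8
num_tokens = 8 * 16 * 16   # 8 frames of a 16 x 16 token grid
height = width = int(np.sqrt(num_tokens // temporal_dims))   # 16
assert height * width * temporal_dims == num_tokens          # check passes
# A non-square grid, e.g. 8 * 16 * 32 tokens, makes the check fail and
# triggers the ValueError above.
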
@@ -766,7 +778,7 @@ class ViViT(nn.Module):
dropout_rate: Dropout rate.
attention_dropout_rate: Dropout for attention heads.
stochastic_droplayer_rate: Probability of dropping a layer. Linearly
increases from 0 to the provided value..
increases from 0 to the provided value.
dtype: JAX data type for the weights and activation functions.
"""

@@ -795,10 +807,10 @@ def __call__(
hidden_size=self.hidden_size,
)(x, train=is_training)

bathc_size, enc_t, enc_h, enc_w, emb_dim = x.shape
batch_size, enc_t, enc_h, enc_w, emb_dim = x.shape
num_tokens = enc_t * enc_h * enc_w

x = jnp.reshape(x, (bathc_size, num_tokens, emb_dim))
x = jnp.reshape(x, (batch_size, num_tokens, emb_dim))

x = TransformerBlock(
temporal_dims=temporal_dims,
@@ -811,6 +823,8 @@ def __call__(
stochastic_droplayer_rate=self.stochastic_droplayer_rate,
dtype=self.dtype,
name='Transformer',
# TODO: clean this input/remove the temporal dims.
encoded_shape=(batch_size, enc_t, enc_h, enc_w, emb_dim)
)(x, train=is_training)

x = TemporalDecoder(
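The reshape in the hunk above amounts to the following shape bookkeeping (illustrative sizes); the new encoded_shape argument simply forwards the pre-flattening 5-D shape so the 3-D axial attention can undo the flattening:

import jax.numpy as jnp

x = jnp.zeros((2, 8, 16, 16, 128))    # (batch, enc_t, enc_h, enc_w, emb_dim)
batch_size, enc_t, enc_h, enc_w, emb_dim = x.shape
num_tokens = enc_t * enc_h * enc_w    # 8 * 16 * 16 = 2048 tokens
tokens = jnp.reshape(x, (batch_size, num_tokens, emb_dim))
encoded_shape = (batch_size, enc_t, enc_h, enc_w, emb_dim)
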
138 changes: 137 additions & 1 deletion swirl_dynamics/lib/diffusion/vivit_diffusion.py
@@ -54,6 +54,12 @@
2: 'space',
})

_AXIS_TO_NAME_3D = dict({
1: 'time',
2: 'height',
3: 'width',
})


class EncoderEmbeddingBlock(nn.Module):
"""Transformer encoder block.
@@ -251,6 +257,125 @@ def _run_attention_on_axis(
return x + y


class Factorized3DSelfAttentionEmbeddingBlock(nn.Module):
"""Encoder with factorized self attention block.
Attributes:
mlp_dim: Dimension of the mlp on top of attention block.
num_heads: Number of heads.
temporal_dims: Number of temporal dimensions in the flattened input
attention_kernel_initializer: Initializer to use for attention layers.
dropout_rate: Dropout rate.
attention_dropout_rate: Dropout for attention heads.
droplayer_p: Probability of dropping a layer.
attention_order: The order to do the attention. In this case the two choices
are 'time_height_width' and 'height_width_time'.
dtype: the dtype of the computation (default: float32).
"""
mlp_dim: int
num_heads: int
three_dim_shape: tuple[int, int, int, int, int]
attention_kernel_initializer: Initializer
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
droplayer_p: Optional[float] = None
attention_order: str = 'time_height_width'
dtype: jnp.dtype = jnp.float32

@nn.compact
def __call__(self, inputs: Array, emb: Array, *, deterministic: bool):
"""Applies Encoder1DBlock module."""

batch_size, num_tokens, emb_dim = inputs.shape
_, enc_t, enc_h, enc_w, _ = self.three_dim_shape

if num_tokens != (enc_t * enc_h * enc_w):
raise ValueError('The product of the encoded dimensions for time, height'
f' and width ({enc_t}, {enc_h}, {enc_w}), respectively, '
'should match the number of tokens '
f'({num_tokens}) in the input.')

inputs = jnp.reshape(inputs, self.three_dim_shape)

self_attention = functools.partial(
nn.SelfAttention,
num_heads=self.num_heads,
kernel_init=self.attention_kernel_initializer,
broadcast_dropout=False,
dropout_rate=self.attention_dropout_rate,
dtype=self.dtype)

# Order of the Axial Transformer.
if self.attention_order == 'time_height_width':
attention_axes = (1, 2, 3)
elif self.attention_order == 'height_width_time':
attention_axes = (2, 3, 1)
else:
raise ValueError(f'Invalid attention order {self.attention_order}.')

def _run_attention_on_axis(
inputs: Array,
emb: Array,
axis: int,
three_dim_shape: tuple[int, int, int, int, int],
):
"""Reshapes the input and run attention on the given axis."""
# shape after the factorized reshape: (batch * other axes, axis_len, emb_dim)
inputs = reshape_utils.reshape_3d_to_1d_factorized(inputs, axis=axis)
x = nn.LayerNorm(
dtype=self.dtype, name='LayerNorm_{}'.format(_AXIS_TO_NAME_3D[axis])
)(inputs)
# Add first FiLM layer; some reshaping is necessary (see Fig. 3 in [1]).
in_shape = x.shape
x = unets.AdaptiveScale()(
x.reshape(three_dim_shape[0], -1, three_dim_shape[-1]), emb
).reshape(in_shape)

x = self_attention(
name='MultiHeadDotProductAttention_{}'.format(_AXIS_TO_NAME_3D[axis])
)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)

# Second FiLM layer (see Fig. 3 in [1]).
in_shape = x.shape
x = unets.AdaptiveScale()(
x.reshape(three_dim_shape[0], -1, three_dim_shape[-1]), emb
).reshape(in_shape)

x = x + inputs

# shape: (batch, time, height, width, emb_dim)
return reshape_utils.reshape_to_3d_factorized(
x, axis=axis, three_d_shape=three_dim_shape
)

x = inputs
three_dim_shape = inputs.shape

# shape: (batch, time, height, width, emb_dim)
for axis in attention_axes:
x = _run_attention_on_axis(x, emb, axis, three_dim_shape)

# MLP block.
x = jnp.reshape(x, (batch_size, num_tokens, emb_dim))
y = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_mlp')(x)
# Add FiLM layer before the MLP block (see Fig. 3 in [1]).
y = unets.AdaptiveScale()(y, emb)
y = vivit.MlpBlock(
mlp_dim=self.mlp_dim,
dtype=self.dtype,
dropout_rate=self.dropout_rate,
activation_fn=nn.gelu,
kernel_init=nn.initializers.xavier_uniform(),
bias_init=nn.initializers.normal(stddev=1e-6),
name='MlpBlock')(
y, deterministic=deterministic)
# Add FiLM layer after the MLP block (see Fig. 3 in [1]).
y = unets.AdaptiveScale()(y, emb)

return x + y
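
The unets.AdaptiveScale calls above apply FiLM-style conditioning on the conditioning embedding (e.g. the noise-level embedding). A minimal sketch of that pattern, assuming a standard scale-and-shift formulation rather than the exact swirl_dynamics implementation:

import flax.linen as nn
import jax.numpy as jnp


class FiLMSketch(nn.Module):
  """Scales and shifts token features using a conditioning embedding."""

  @nn.compact
  def __call__(self, x, emb):
    # x: (batch, num_tokens, emb_dim); emb: (batch, cond_dim).
    scale_shift = nn.Dense(features=2 * x.shape[-1])(nn.swish(emb))
    scale, shift = jnp.split(scale_shift, 2, axis=-1)
    # Broadcast the per-channel modulation over the token axis.
    return x * (1.0 + scale[:, None, :]) + shift[:, None, :]
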


class TransformerEmbeddingBlock(nn.Module):
"""Transformer Block with embeddings.
Expand All @@ -269,6 +394,8 @@ class TransformerEmbeddingBlock(nn.Module):
uses independent dropping patterns for each skip-connection.
positional_embedding: The type of positional embedding to use. Supported
values are {learned_1d, sinusoidal_1d, sinusoidal_3d, none}.
encoded_shape: Shape of the equivalent three-dimensional (space-time) tensor,
i.e. (batch, time, height, width, emb_dim). This is required for the
three-dimensional axial attention block.
normalise_output: If True, perform layernorm on the output.
"""

@@ -282,6 +409,7 @@ class TransformerEmbeddingBlock(nn.Module):
stochastic_droplayer_rate: float = 0.0
dtype: jnp.dtype = jnp.float32
positional_embedding: str = 'sinusoidal_3d'
encoded_shape: tuple[int, ...] | None = None
normalise_output: bool = True

@nn.compact
@@ -324,7 +452,14 @@ def __call__(self, inputs: Array, emb: Array, *, train: bool) -> Array:
self.attention_config.get('attention_kernel_init_method',
'xavier')], # pytype: disable=attribute-error
temporal_dims=self.temporal_dims)
# TODO: implement factorized_dot_product_attention.
elif self.attention_config.type == 'factorized_3d_self_attention_block': # pytype: disable=attribute-error
encoder_block = functools.partial(
Factorized3DSelfAttentionEmbeddingBlock,
attention_order=self.attention_config.attention_order, # pytype: disable=attribute-error
attention_kernel_initializer=_KERNEL_INITIALIZERS[
self.attention_config.get('attention_kernel_init_method', # pytype: disable=attribute-error
'xavier')],
three_dim_shape=self.encoded_shape)
else:
raise ValueError(f'Unknown attention type {self.attention_config.type}') # pytype: disable=attribute-error
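
Based on the updated test below, a config that exercises this new branch looks like:

import ml_collections

attention_config = ml_collections.ConfigDict()
attention_config.type = 'factorized_3d_self_attention_block'
attention_config.attention_order = 'time_height_width'  # or 'height_width_time'
attention_config.attention_kernel_init_method = 'xavier'
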

@@ -447,6 +582,7 @@ def __call__(
stochastic_droplayer_rate=self.stochastic_droplayer_rate,
positional_embedding=self.positional_embedding,
dtype=self.dtype,
encoded_shape=(batch_size, enc_t, enc_h, enc_w, emb_dim),
name='Transformer')(
x, emb, train=is_training)

49 changes: 42 additions & 7 deletions swirl_dynamics/lib/diffusion/vivit_diffusion_test.py
@@ -24,14 +24,49 @@
class VivitDiffusionTest(parameterized.TestCase):

@parameterized.parameters(
((16, 32, 32), (2, 2, 2), 1),
((16, 64, 64), (4, 4, 4), 1,),
((32, 64, 64), (8, 8, 8), 2,),
(
(16, 32, 32),
(2, 2, 2),
1,
3,
'factorized_self_attention_block',
'time_space',
),
(
(16, 64, 64),
(4, 4, 4),
1,
3,
'factorized_self_attention_block',
'space_time',
),
(
(32, 64, 64),
(8, 8, 8),
2,
6,
'factorized_3d_self_attention_block',
'time_height_width',
),
(
(32, 32, 32),
(4, 4, 4),
2,
6,
'factorized_3d_self_attention_block',
'height_width_time',
),
)
def test_vivit_diffusion_output_shape(
self, spatial_dims, patch_size, output_features
self,
spatial_dims,
patch_size,
output_features,
channels,
attention_type,
attention_order,
):
batch, channels = 2, 3
batch = 2

x = np.random.randn(batch, *spatial_dims, channels)
sigma = np.linspace(0, 1, batch)
@@ -44,8 +79,8 @@ def test_vivit_diffusion_output_shape(
temporal_encoding_config.kernel_init_method = 'central_frame_initializer'

attention_config = ml_collections.ConfigDict()
attention_config.type = 'factorized_self_attention_block'
attention_config.attention_order = 'time_space'
attention_config.type = attention_type
attention_config.attention_order = attention_order
attention_config.attention_kernel_init_method = 'xavier'

vivit_model = vivit_diffusion.ViViTDiffusion(