Commit 517a442

supporting defaulting dropout rates to None
znado committed Oct 7, 2022
1 parent 0f8ac82 commit 517a442
Showing 14 changed files with 213 additions and 60 deletions.
algorithmic_efficiency/pytorch_utils.py (6 changes: 5 additions & 1 deletion)
@@ -44,13 +44,17 @@ def logging_pass(*args):


# torch.nn.functional.dropout will not be affected by this function.
-def update_dropout(model, dropout_rate):
+def maybe_update_dropout(model, dropout_rate):
+  if dropout_rate is None:
+    return
  for child in list(model.modules()):
    if isinstance(child, torch.nn.Dropout):
      child.p = dropout_rate


def update_attention_dropout(model, attention_dropout_rate):
+  if attention_dropout_rate is None:
+    return
  for child in list(model.modules()):
    if isinstance(child, torch.nn.MultiheadAttention):
      child.dropout = attention_dropout_rate
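
For context, a minimal usage sketch of the renamed helper: passing None is now a no-op, while a float still overwrites the rate of every torch.nn.Dropout module found. The toy model and the package-level import below are assumptions for illustration; only maybe_update_dropout itself comes from this commit.

import torch

from algorithmic_efficiency import pytorch_utils  # Assumed package-level import.

# Hypothetical toy model, used only to illustrate the helper.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.5))

pytorch_utils.maybe_update_dropout(model, None)  # No-op: Dropout.p stays 0.5.
pytorch_utils.maybe_update_dropout(model, 0.1)   # Overwrites: Dropout.p becomes 0.1.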
algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py (16 changes: 10 additions & 6 deletions)
@@ -12,6 +12,7 @@
Data:
  github.com/facebookresearch/fastMRI/tree/main/fastmri/data
"""
+from typing import Optional

import flax.linen as nn
import jax
@@ -63,27 +64,30 @@ class UNet(nn.Module):
  out_channels: int = 1
  channels: int = 32
  num_pool_layers: int = 4
-  dropout_rate: float = 0.0
+  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.

  @nn.compact
  def __call__(self, x, train=True):
-    down_sample_layers = [ConvBlock(self.channels, self.dropout_rate)]
+    dropout_rate = self.dropout_rate
+    if dropout_rate is None:
+      dropout_rate = 0.0
+    down_sample_layers = [ConvBlock(self.channels, dropout_rate)]

    ch = self.channels
    for _ in range(self.num_pool_layers - 1):
-      down_sample_layers.append(ConvBlock(ch * 2, self.dropout_rate))
+      down_sample_layers.append(ConvBlock(ch * 2, dropout_rate))
      ch *= 2
-    conv = ConvBlock(ch * 2, self.dropout_rate)
+    conv = ConvBlock(ch * 2, dropout_rate)

    up_conv = []
    up_transpose_conv = []
    for _ in range(self.num_pool_layers - 1):
      up_transpose_conv.append(TransposeConvBlock(ch))
-      up_conv.append(ConvBlock(ch, self.dropout_rate))
+      up_conv.append(ConvBlock(ch, dropout_rate))
      ch //= 2

    up_transpose_conv.append(TransposeConvBlock(ch))
-    up_conv.append(ConvBlock(ch, self.dropout_rate))
+    up_conv.append(ConvBlock(ch, dropout_rate))

    final_conv = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))

@@ -136,7 +136,7 @@ def model_fn(
    del update_batch_norm

    model = params
-    pytorch_utils.update_dropout(model, dropout_rate)
+    pytorch_utils.maybe_update_dropout(model, dropout_rate)

    if mode == spec.ForwardPassMode.EVAL:
      model.eval()
@@ -146,7 +146,7 @@ class ViT(nn.Module):
  num_heads: int = 12
  posemb: str = 'sincos2d'  # Can also be "learn"
  rep_size: Union[int, bool] = True
-  dropout_rate: float = 0.0
+  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
  pool_type: str = 'gap'  # Can also be 'map' or 'tok'
  reinit: Optional[Sequence[str]] = None
  head_zeroinit: bool = True
@@ -175,14 +175,17 @@ def __call__(self, x, *, train=False):
      cls = self.param('cls', nn.initializers.zeros, (1, 1, c), x.dtype)
      x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)

-    n, l, c = x.shape  # pylint: disable=unused-variable
-    x = nn.Dropout(rate=self.dropout_rate)(x, not train)
+    n, _, c = x.shape
+    dropout_rate = self.dropout_rate
+    if dropout_rate is None:
+      dropout_rate = 0.0
+    x = nn.Dropout(rate=dropout_rate)(x, not train)

    x, out['encoder'] = Encoder(
        depth=self.depth,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
-        dropout_rate=self.dropout_rate,
+        dropout_rate=dropout_rate,
        name='Transformer')(
            x, train=not train)
    encoded = out['encoded'] = x
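The Flax models resolve a None rate at call time rather than through a shared helper. Below is a minimal standalone sketch of that pattern, using a hypothetical Block module: only the Optional field default, the None check, and the nn.Dropout call mirror the UNet and ViT changes above; everything else is illustrative.

from typing import Optional

import flax.linen as nn


class Block(nn.Module):
  # If None, defaults to 0.0, matching the UNet and ViT changes above.
  dropout_rate: Optional[float] = 0.0

  @nn.compact
  def __call__(self, x, train=True):
    dropout_rate = self.dropout_rate
    if dropout_rate is None:
      dropout_rate = 0.0
    # Dropout is disabled (deterministic) when not training.
    return nn.Dropout(rate=dropout_rate)(x, deterministic=not train)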
@@ -54,7 +54,7 @@ def model_fn(
    del update_batch_norm

    model = params
-    pytorch_utils.update_dropout(model, dropout_rate)
+    pytorch_utils.maybe_update_dropout(model, dropout_rate)

    if mode == spec.ForwardPassMode.EVAL:
      model.eval()
@@ -13,7 +13,7 @@
"""

import math
-from typing import Any, List
+from typing import Any, List, Optional

from flax import linen as nn
from flax import struct
@@ -36,10 +36,13 @@ class ConformerConfig:
  num_attention_heads: int = 8
  num_encoder_layers: int = 4
  attention_dropout_rate: float = 0.0
-  attention_residual_dropout_rate: float = 0.1
-  conv_residual_dropout_rate: float = 0.0
+  # If None, defaults to 0.1.
+  attention_residual_dropout_rate: Optional[float] = 0.1
+  # If None, defaults to 0.0.
+  conv_residual_dropout_rate: Optional[float] = 0.0
  feed_forward_dropout_rate: float = 0.0
-  feed_forward_residual_dropout_rate: float = 0.1
+  # If None, defaults to 0.1.
+  feed_forward_residual_dropout_rate: Optional[float] = 0.1
  convolution_kernel_size: int = 5
  feed_forward_expansion_factor: int = 4
  freq_mask_count: int = 2
@@ -49,7 +52,8 @@ class ConformerConfig:
  time_mask_max_ratio: float = 0.05
  time_masks_per_frame: float = 0.0
  use_dynamic_time_mask_max_frames: bool = True
-  input_dropout_rate: float = 0.1
+  # If None, defaults to 0.1.
+  input_dropout_rate: Optional[float] = 0.1
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001

@@ -212,7 +216,12 @@ def __call__(self, inputs, padding_mask=None, train=False):
        inputs)
    inputs = inputs * padding_mask

-    inputs = nn.Dropout(rate=config.feed_forward_residual_dropout_rate)(
+    if config.feed_forward_residual_dropout_rate is None:
+      feed_forward_residual_dropout_rate = 0.1
+    else:
+      feed_forward_residual_dropout_rate = (
+          config.feed_forward_residual_dropout_rate)
+    inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)(
        inputs, deterministic=not train)

    return inputs
@@ -386,9 +395,12 @@ def __call__(self, inputs, paddings, train):
        dropout_rate=config.attention_dropout_rate,
        deterministic=not train)(inputs, attention_mask)

+    if config.attention_residual_dropout_rate is None:
+      attention_residual_dropout_rate = 0.1
+    else:
+      attention_residual_dropout_rate = config.attention_residual_dropout_rate
    result = nn.Dropout(
-        rate=config.attention_residual_dropout_rate, deterministic=not train)(
-            result)
+        rate=attention_residual_dropout_rate, deterministic=not train)(result)

    return result

@@ -523,9 +535,12 @@ def __call__(self, inputs, input_paddings, train):
        config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
            inputs)

+    if config.conv_residual_dropout_rate is None:
+      conv_residual_dropout_rate = 0.0
+    else:
+      conv_residual_dropout_rate = config.conv_residual_dropout_rate
    inputs = nn.Dropout(
-        rate=config.conv_residual_dropout_rate, deterministic=not train)(
-            inputs)
+        rate=conv_residual_dropout_rate, deterministic=not train)(inputs)
    return inputs


@@ -607,9 +622,13 @@ def __call__(self, inputs, input_paddings, train):
      outputs, output_paddings = self.specaug(outputs, output_paddings)

    # Subsample input by a factor of 4 by performing strided convolutions.
+    if config.input_dropout_rate is None:
+      input_dropout_rate = 0.1
+    else:
+      input_dropout_rate = config.input_dropout_rate
    outputs, output_paddings = Subsample(
        encoder_dim=config.encoder_dim,
-        input_dropout_rate=config.input_dropout_rate)(
+        input_dropout_rate=input_dropout_rate)(
            outputs, output_paddings, train)

    # Run the conformer encoder layers.
@@ -23,17 +23,22 @@
MAX_INPUT_LENGTH = 320000


-def _update_model_dropout(model, residual_dropout_rate, input_dropout_rate):
+def _maybe_update_model_dropout(
+    model, residual_dropout_rate, input_dropout_rate):
  for child in list(model.modules()):
    # Residual dropout.
-    if isinstance(child, conformer_model.MultiHeadedSelfAttention):
+    if (isinstance(child, conformer_model.MultiHeadedSelfAttention) and
+        residual_dropout_rate is not None):
      child.dropout.p = residual_dropout_rate
-    elif isinstance(child, conformer_model.ConvolutionBlock):
+    elif (isinstance(child, conformer_model.ConvolutionBlock) and
+          residual_dropout_rate is not None):
      child.dropout.p = residual_dropout_rate
-    elif isinstance(child, conformer_model.FeedForwardModule):
+    elif (isinstance(child, conformer_model.FeedForwardModule) and
+          residual_dropout_rate is not None):
      child.dropout2.p = residual_dropout_rate
    # Input dropout.
-    elif isinstance(child, conformer_model.Subsample):
+    elif (isinstance(child, conformer_model.Subsample) and
+          input_dropout_rate is not None):
      child.dropout.p = input_dropout_rate


@@ -88,7 +93,7 @@ def model_fn(
    del update_batch_norm

    model = params
-    _update_model_dropout(
+    _maybe_update_model_dropout(
        model,
        residual_dropout_rate=dropout_rate,
        input_dropout_rate=aux_dropout_rate)
@@ -51,8 +51,10 @@ class DeepspeechConfig:
  use_dynamic_time_mask_max_frames: bool = True
  batch_norm_momentum: float = 0.999
  batch_norm_epsilon: float = 0.001
-  input_dropout_rate: float = 0.1
-  feed_forward_dropout_rate: float = 0.1
+  # If None, defaults to 0.1.
+  input_dropout_rate: Optional[float] = 0.1
+  # If None, defaults to 0.1.
+  feed_forward_dropout_rate: Optional[float] = 0.1
  enable_residual_connections: bool = True
  enable_decoder_layer_norm: bool = True
  bidirectional: bool = True
@@ -99,9 +101,12 @@ def __call__(self, inputs, output_paddings, train):
        kernel_init=nn.initializers.xavier_uniform())(
            outputs)

+    if config.input_dropout_rate is None:
+      input_dropout_rate = 0.1
+    else:
+      input_dropout_rate = config.input_dropout_rate
    outputs = nn.Dropout(
-        rate=config.input_dropout_rate, deterministic=not train)(
-            outputs)
+        rate=input_dropout_rate, deterministic=not train)(outputs)

    return outputs, output_paddings

@@ -188,7 +193,11 @@ def __call__(self, inputs, input_paddings=None, train=False):
    inputs = nn.relu(inputs)
    inputs *= padding_mask

-    inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
+    if config.feed_forward_dropout_rate is None:
+      feed_forward_dropout_rate = 0.1
+    else:
+      feed_forward_dropout_rate = config.feed_forward_dropout_rate
+    inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
        inputs, deterministic=not train)

    return inputs
@@ -159,7 +159,7 @@ def model_fn(
      raise ValueError(
          f'Expected model_state to be None, received {model_state}.')
    model = params
-    pytorch_utils.update_dropout(model, dropout_rate)
+    pytorch_utils.maybe_update_dropout(model, dropout_rate)

    if mode == spec.ForwardPassMode.TRAIN:
      model.train()

0 comments on commit 517a442
