Skip to content

Commit

Permalink
Merge pull request #798 from mlcommons/bn_fixes_clean
Browse files (browse the repository at this point in the history)
BN fixes JAX
  • Loading branch information
priyakasimbeg authored Oct 29, 2024
2 parents 7a3710f + 894cd87 commit 787f7fb
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 44 deletions.
11 changes: 7 additions & 4 deletions algorithmic_efficiency/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
)
if isinstance(module, bn_layers):
if not update_batch_norm:
module.eval()
module.momentum_backup = module.momentum
if not hasattr(module, 'momentum_backup'):
module.momentum_backup = module.momentum

# module.momentum can be float or torch.Tensor.
module.momentum = 0. * module.momentum_backup
if torch.is_tensor(module.momentum_backup):
module.momentum = torch.zeros_like(module.momentum_backup)
else:
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup
module.track_running_stats = update_batch_norm
9 changes: 7 additions & 2 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: bool = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
Expand Down
10 changes: 7 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
Expand All @@ -119,14 +121,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
Expand All @@ -157,14 +159,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,15 +454,24 @@ def setup(self):
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
update_batch_norm,
use_running_average_bn):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))

padding = jnp.expand_dims(input_paddings, -1)
momentum = self.config.batch_norm_momentum
epsilon = self.config.batch_norm_epsilon

if train:
if use_running_average_bn:
mean = self.ra_mean.value
var = self.ra_var.value

else:
# compute batch statistics
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
Expand All @@ -478,16 +487,13 @@ def __call__(self, inputs, input_paddings, train):

var = sum_vv / count_v

self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var
else:
mean = self.ra_mean.value
var = self.ra_var.value
if update_batch_norm:
self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var

inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)

bn_output = (inputs - mean) * inv + self.beta
bn_output *= 1.0 - padding

Expand Down Expand Up @@ -517,7 +523,12 @@ class ConvolutionBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average_bn):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

Expand Down Expand Up @@ -546,7 +557,10 @@ def __call__(self, inputs, input_paddings, train):
kernel_init=nn.initializers.xavier_uniform())(
inputs)

inputs = BatchNorm(config)(inputs, input_paddings, train)
inputs = BatchNorm(config)(inputs,
input_paddings,
update_batch_norm,
use_running_average_bn)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
Expand Down Expand Up @@ -586,7 +600,12 @@ class ConformerBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm,
use_running_average):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)

Expand All @@ -597,7 +616,12 @@ def __call__(self, inputs, input_paddings, train):
inputs, input_paddings, train)

inputs = inputs + \
ConvolutionBlock(config)(inputs, input_paddings, train)
ConvolutionBlock(config)(inputs,
input_paddings,
train,
update_batch_norm,
use_running_average
)

inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
inputs, padding_mask, train)
Expand Down Expand Up @@ -629,12 +653,23 @@ def setup(self):
.use_dynamic_time_mask_max_frames)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None):
config = self.config

outputs = inputs
output_paddings = input_paddings

# Set BN args if not supplied for backwards compatibility
if update_batch_norm is None:
update_batch_norm = train
if use_running_average_bn is None:
use_running_average_bn = not train

# Compute normalized log mel spectrograms from input audio signal.
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
Expand All @@ -660,7 +695,11 @@ def __call__(self, inputs, input_paddings, train):

# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
outputs = ConformerBlock(config)(outputs, output_paddings, train)
outputs = ConformerBlock(config)(outputs,
output_paddings,
train,
update_batch_norm,
use_running_average_bn)

outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
Expand All @@ -118,15 +120,17 @@ def model_fn(
input_paddings,
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
variables,
inputs,
input_paddings,
train=False,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), model_state

def _build_input_queue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ConformerConfig:
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
input_dropout_rate: float = 0.1
batch_norm_momentum: float = 0.999
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
attention_temperature: float = 1.0
Expand Down Expand Up @@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings):
mean = (masked_inp).sum(dim=(0, 1)) / count
var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count

self.running_mean = self.momentum * self.running_mean + (
1 - self.momentum) * mean.detach()
self.running_var = self.momentum * self.running_var + (
1 - self.momentum) * var.detach()
self.running_mean = (1 - self.momentum) * self.running_mean + (
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()

else:
mean = self.running_mean
var = self.running_var
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Optional
from typing import Dict, Optional, Tuple

from flax import jax_utils
import jax
Expand Down Expand Up @@ -56,6 +56,37 @@ def init_model_fn(
params = jax_utils.replicate(params)
return params, model_state

def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool,
    use_running_average_bn: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
  """Run a forward pass of the model on one preprocessed batch.

  Args:
    params: Model parameters, passed to the model under the 'params' key.
    augmented_and_preprocessed_input_batch: Batch dict whose 'inputs' entry
      unpacks into `(inputs, input_paddings)`.
    model_state: Auxiliary collections merged into the variables passed to
      the model alongside `params`.
    mode: Forward pass mode; `spec.ForwardPassMode.TRAIN` selects the
      training path.
    rng: Random state supplied to the model as the 'dropout' rng on the
      training path.
    update_batch_norm: When truthy, the training path is taken regardless
      of `mode`.
    use_running_average_bn: Accepted as part of the signature but not
      forwarded to the model by this method.

  Returns:
    A pair `((logits, logit_paddings), state)`. On the training path `state`
    is the updated state returned by the model (with 'batch_stats' mutable);
    otherwise it is the `model_state` that was passed in.
  """
  variables = {'params': params, **model_state}
  inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
  in_train_mode = mode == spec.ForwardPassMode.TRAIN

  # Inference path: nothing is mutable, so the incoming state is handed back.
  if not (update_batch_norm or in_train_mode):
    logits, logit_paddings = self._model.apply(
        variables, inputs, input_paddings, train=False, mutable=False)
    return (logits, logit_paddings), model_state

  # Training path: dropout needs an rng and 'batch_stats' may be rewritten.
  model_outputs, new_model_state = self._model.apply(
      variables,
      inputs,
      input_paddings,
      train=True,
      rngs={'dropout' : rng},
      mutable=['batch_stats'])
  logits, logit_paddings = model_outputs
  return (logits, logit_paddings), new_model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class DeepspeechConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 0.999
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
# If None, defaults to 0.1.
input_dropout_rate: Optional[float] = 0.1
Expand Down Expand Up @@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings):
sum_ = dist_nn.all_reduce(sum_)
var = sum_ / count

self.running_mean = self.momentum * self.running_mean + (
1 - self.momentum) * mean.detach()
self.running_var = self.momentum * self.running_var + (
1 - self.momentum) * var.detach()
self.running_mean = (1 - self.momentum) * self.running_mean + (
self.momentum) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
self.momentum) * var.detach()
else:
mean = self.running_mean
var = self.running_var
Expand Down
1 change: 1 addition & 0 deletions tests/reference_algorithm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def _test_submission(workload_name,
workload_path=workload_metadata['workload_path'],
workload_class_name=workload_metadata['workload_class_name'],
return_class=True)
print(f'Workload class for {workload_name} is {workload_class}')

submission_module_path = workloads.convert_filepath_to_module(submission_path)
submission_module = importlib.import_module(submission_module_path)
Expand Down

0 comments on commit 787f7fb

Please sign in to comment.