diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 834c93b7a..09338ca82 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -28,11 +28,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..019dde38c 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -110,7 +110,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -119,14 +120,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 99a9b0513..2e680cbd9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -84,11 +84,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..46168c2a0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -148,7 +148,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -157,14 +158,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..077ff0f89 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,7 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +462,12 @@ def __call__(self, inputs, input_paddings, train): momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if train: + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( @@ -477,17 +482,14 @@ def __call__(self, inputs, input_paddings, train): keepdims=True) var = sum_vv / count_v - - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value - + + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -517,7 +519,7 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -546,7 +548,7 @@ def __call__(self, inputs, input_paddings, train): kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, train) + inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -586,7 +588,7 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -597,7 +599,7 @@ def __call__(self, inputs, input_paddings, train): inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train) + ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) @@ -629,12 +631,23 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs output_paddings = input_paddings + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( @@ -660,7 +673,7 @@ def __call__(self, inputs, input_paddings, train): # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train) + outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..6c55acfb0 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -107,7 +107,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN @@ -118,7 +119,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -126,7 +128,8 @@ def model_fn( inputs, input_paddings, train=False, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), model_state def _build_input_queue(