From b24812ff74d1353a2d56d3cffb86952298836f04 Mon Sep 17 00:00:00 2001 From: Aaron Defazio Date: Thu, 5 Sep 2024 15:20:04 +0000 Subject: [PATCH 01/11] BN Fixes --- algorithmic_efficiency/pytorch_utils.py | 14 ++++++++------ .../librispeech_pytorch/models.py | 11 ++++++----- .../librispeech_pytorch/models.py | 10 +++++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 4f6c254bd..2e5828912 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -57,7 +57,6 @@ def sync_ddp_time(time: float, device: torch.device) -> float: dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) return time_tensor.item() - def update_batch_norm_fn(module: spec.ParameterContainer, update_batch_norm: bool) -> None: bn_layers = ( @@ -67,10 +66,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer, ) if isinstance(module, bn_layers): if not update_batch_norm: - module.eval() - module.momentum_backup = module.momentum + if not hasattr(module, 'momentum_backup'): + module.momentum_backup = module.momentum + # module.momentum can be float or torch.Tensor. - module.momentum = 0. * module.momentum_backup + if torch.is_tensor(module.momentum_backup): + module.momentum = torch.zeros_like(module.momentum_backup) + else: + module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): - module.momentum = module.momentum_backup - module.track_running_stats = update_batch_norm + module.momentum = module.momentum_backup \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 502cb093e..cab73df4a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -40,7 +40,7 @@ class ConformerConfig: time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True input_dropout_rate: float = 0.1 - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True attention_temperature: float = 1.0 @@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings): mean = (masked_inp).sum(dim=(0, 1)) / count var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: mean = self.running_mean var = self.running_var diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index a5ee3fa0a..bdf556f1c 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -36,7 +36,7 @@ class DeepspeechConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 0.999 + 
batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 # If None, defaults to 0.1. input_dropout_rate: Optional[float] = 0.1 @@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings): sum_ = dist_nn.all_reduce(sum_) var = sum_ / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() else: mean = self.running_mean var = self.running_var From f574bf04dda725f790bb6ffaf2ca62b260b132d8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 01:01:14 +0000 Subject: [PATCH 02/11] add use_running_average_bn arg for jax --- .../workloads/cifar/cifar_jax/models.py | 9 +++- .../workloads/cifar/cifar_jax/workload.py | 9 ++-- .../imagenet_resnet/imagenet_jax/models.py | 9 +++- .../imagenet_resnet/imagenet_jax/workload.py | 9 ++-- .../librispeech_jax/models.py | 49 ++++++++++++------- .../librispeech_jax/workload.py | 9 ++-- 6 files changed, 63 insertions(+), 31 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 834c93b7a..09338ca82 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -28,11 +28,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..019dde38c 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -110,7 +110,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -119,14 +120,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py 
b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 99a9b0513..2e680cbd9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -84,11 +84,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..46168c2a0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -148,7 +148,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -157,14 +158,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..077ff0f89 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,7 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +462,12 @@ def __call__(self, inputs, input_paddings, train): momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if train: + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( @@ -477,17 +482,14 @@ def __call__(self, inputs, input_paddings, train): keepdims=True) var = sum_vv / count_v - - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * 
\ - self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value - + + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -517,7 +519,7 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -546,7 +548,7 @@ def __call__(self, inputs, input_paddings, train): kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, train) + inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -586,7 +588,7 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -597,7 +599,7 @@ def __call__(self, inputs, input_paddings, train): inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train) + ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) @@ -629,12 +631,23 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs output_paddings = input_paddings + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( @@ -660,7 +673,7 @@ def __call__(self, inputs, input_paddings, train): # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train) + outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
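
The diff above decouples two things that Flax's stock BatchNorm ties together: which statistics normalize the activations (use_running_average_bn) and whether the running statistics are written back (update_batch_norm). A minimal Flax sketch of that behaviour follows; the module name, feature size and epsilon are placeholders, not the repository's BatchNorm.

import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyBatchNorm(nn.Module):
  dim: int = 8
  momentum: float = 0.999  # decay convention, as in the Flax-style code above
  eps: float = 1e-3

  @nn.compact
  def __call__(self, x, update_batch_norm, use_running_average_bn):
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda d: jnp.zeros(d, jnp.float32), self.dim)
    ra_var = self.variable('batch_stats', 'var',
                           lambda d: jnp.ones(d, jnp.float32), self.dim)
    if use_running_average_bn:
      # Standard eval path: normalize with the stored running statistics.
      mean, var = ra_mean.value, ra_var.value
    else:
      # Normalize with the statistics of the current batch.
      mean = jnp.mean(x, axis=0)
      var = jnp.mean(jnp.square(x - mean), axis=0)
      if update_batch_norm:
        # Only this branch mutates the 'batch_stats' collection.
        ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
        ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
    return (x - mean) / jnp.sqrt(var + self.eps)


# Training step: batch statistics normalize and the running ones are updated.
model = ToyBatchNorm()
x = jnp.ones((4, 8))
variables = model.init(jax.random.PRNGKey(0), x,
                       update_batch_norm=True, use_running_average_bn=False)
_, new_stats = model.apply(variables, x, update_batch_norm=True,
                           use_running_average_bn=False,
                           mutable=['batch_stats'])
# The third mode (update_batch_norm=False, use_running_average_bn=False)
# normalizes with batch statistics without touching 'batch_stats'.
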
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..6c55acfb0 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -107,7 +107,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN @@ -118,7 +119,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -126,7 +128,8 @@ def model_fn( inputs, input_paddings, train=False, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), model_state def _build_input_queue( From 7ca8365a7aba462181736b8d39382162c9bb1ad6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 22:01:58 +0000 Subject: [PATCH 03/11] formatting --- algorithmic_efficiency/pytorch_utils.py | 3 +- .../workloads/cifar/cifar_jax/models.py | 6 +-- .../workloads/cifar/cifar_jax/workload.py | 3 +- .../imagenet_resnet/imagenet_jax/models.py | 4 +- .../imagenet_resnet/imagenet_jax/workload.py | 3 +- .../librispeech_jax/models.py | 49 +++++++++++++------ .../librispeech_jax/workload.py | 3 +- .../librispeech_pytorch/models.py | 2 +- 8 files changed, 49 insertions(+), 24 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 2e5828912..590f500fa 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -57,6 +57,7 @@ def sync_ddp_time(time: float, device: torch.device) -> float: dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) return time_tensor.item() + def update_batch_norm_fn(module: spec.ParameterContainer, update_batch_norm: bool) -> None: bn_layers = ( @@ -75,4 +76,4 @@ def update_batch_norm_fn(module: spec.ParameterContainer, else: module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): - module.momentum = module.momentum_backup \ No newline at end of file + module.momentum = module.momentum_backup diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 09338ca82..059352fb6 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -31,10 +31,10 @@ def __call__(self, update_batch_norm: bool = True, use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - - # Preserve default behavior for backwards compatibility + + # Preserve default behavior for backwards compatibility if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, 
use_running_average=use_running_average_bn, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 019dde38c..8268c6ca3 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -111,7 +111,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 2e680cbd9..34cd17440 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -88,9 +88,9 @@ def __call__(self, use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - # Preserve default behavior for backwards compatibility + # Preserve default behavior for backwards compatibility if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, use_running_average=use_running_average_bn, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 46168c2a0..2747fc2db 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -149,7 +149,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index 077ff0f89..db92f56d4 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,11 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): + def __call__(self, + inputs, + input_paddings, + update_batch_norm, + use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +466,7 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if use_running_average_bn: + if use_running_average_bn: mean = self.ra_mean.value var = self.ra_var.value @@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag keepdims=True) var = sum_vv / count_v - + if update_batch_norm: self.ra_mean.value = momentum * \ 
self.ra_mean.value + (1 - momentum) * mean self.ra_var.value = momentum * \ self.ra_var.value + (1 - momentum) * var - + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) + inputs = BatchNorm(config)(inputs, + input_paddings, + update_batch_norm, + use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -588,7 +600,12 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -631,12 +648,12 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs @@ -673,7 +690,11 @@ def __call__(self, # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) + outputs = ConformerBlock(config)(outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
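
A submission reaches these flags through model_fn, whose new signature appears in the workload diffs of this series. Below is a hedged usage sketch for the speech workloads; the helper name and the workload, batch and rng objects are illustrative placeholders. It evaluates with the current batch's statistics while leaving the stored batch_stats untouched, the same combination the PyTorch-side momentum freeze in PATCH 01 emulates.

from algorithmic_efficiency import spec


def forward_with_frozen_bn(workload, params, model_state, batch, rng):
  # EVAL mode with update_batch_norm=False takes the non-mutable apply path,
  # and use_running_average_bn=False still selects batch statistics.
  (logits, logit_paddings), _ = workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.EVAL,
      rng,
      update_batch_norm=False,
      use_running_average_bn=False)
  return logits, logit_paddings
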
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 6c55acfb0..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -108,7 +108,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index cab73df4a..61400806a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings): self.momentum) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( self.momentum) * var.detach() - + else: mean = self.running_mean var = self.running_var From 39132387e411a5e869a98c5b57fbd1e7b0d12194 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 22:37:04 +0000 Subject: [PATCH 04/11] formatting --- .../librispeech_conformer/librispeech_jax/models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index db92f56d4..a7f786c32 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -616,7 +616,12 @@ def __call__(self, inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, + use_running_average + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) From baac0a452871ef5a940c07dbfd64f6d3b9c5427d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 23:26:42 +0000 Subject: [PATCH 05/11] formatting --- .../librispeech_conformer/librispeech_jax/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index a7f786c32..cb6287c5e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -616,10 +616,10 @@ def __call__(self, inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, - input_paddings, - train, - update_batch_norm, + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, use_running_average ) From 
087fd5c1e8400d1ab162cbf79e0fad6a828dae5f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 23:58:10 +0000 Subject: [PATCH 06/11] debugging --- .../workloads/librispeech_conformer/librispeech_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index e362f973b..3caf151ab 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -113,6 +113,7 @@ def model_fn( variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN + print(type(use_running_average_bn)) if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( variables, From c5c36c291f2c2a5a21bc0b60961a7016039e93ae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:30:40 +0000 Subject: [PATCH 07/11] add seperate model_fn for deepspeech jax without use_running_average_bn --- .../librispeech_jax/workload.py | 1 - .../librispeech_jax/workload.py | 31 +++++++++++++++++++ tests/reference_algorithm_tests.py | 1 + 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 3caf151ab..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -113,7 +113,6 @@ def model_fn( variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN - print(type(use_running_average_bn)) if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( variables, diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 8473fac0f..c81b1b0b4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,6 +55,37 @@ def init_model_fn( model_state = jax_utils.replicate(model_state) params = jax_utils.replicate(params) return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = 
self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..5e563d2f9 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -408,6 +408,7 @@ def _test_submission(workload_name, workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], return_class=True) + print(f'Workload class for {workload_name} is {workload_class}') submission_module_path = workloads.convert_filepath_to_module(submission_path) submission_module = importlib.import_module(submission_module_path) From 783aab4a3c2952823290c3e3881b0e423231a2ae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:32:36 +0000 Subject: [PATCH 08/11] fix syntax error --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index c81b1b0b4..05fdf90e7 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -56,7 +56,7 @@ def init_model_fn( params = jax_utils.replicate(params) return params, model_state - def model_fn( + def model_fn( self, params: spec.ParameterContainer, augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], From 28e7e21a334001f3e62ce15875f3126af0affbd6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:37:29 +0000 Subject: [PATCH 09/11] fix --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 05fdf90e7..2d46960ed 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional +from typing import Optional, Dict, Tuple from flax import jax_utils import jax From b063f9f7fa5288736c006ff111454e77869ada8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:43:44 +0000 Subject: [PATCH 10/11] fix import order --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2d46960ed..e5030f426 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional, Dict, Tuple +from typing import Dict, Optional, Tuple from flax import jax_utils import jax From 894cd872aa07d97e2a4fe0ce9e5a0dcdb790bfb8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:53:11 +0000 Subject: [PATCH 11/11] 
formatting

---
 .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index e5030f426..a0db6d607 100644
--- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -55,7 +55,7 @@ def init_model_fn(
     model_state = jax_utils.replicate(model_state)
     params = jax_utils.replicate(params)
     return params, model_state
-  
+
   def model_fn(
       self,
       params: spec.ParameterContainer,
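
The series pairs the JAX-side use_running_average_bn switch with an equivalent mechanism on the PyTorch side: PATCH 01's update_batch_norm_fn backs up and zeroes a BatchNorm layer's momentum, so that, with the layer left in training mode, normalization still uses batch statistics while the running estimates stay frozen, and restores the backup afterwards. Because the custom speech BatchNorm layers are also switched to PyTorch's momentum convention in that patch (batch_norm_momentum = 1 - 0.999 and running = (1 - momentum) * running + momentum * batch), a zero momentum turns the running-stat update into a no-op. A condensed, self-contained sketch of that helper follows; the bn_layers tuple is elided in the PATCH 01 hunk, so the public BatchNorm classes listed here are an assumption.

import torch
from torch import nn

BN_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def update_batch_norm_fn(module: nn.Module, update_batch_norm: bool) -> None:
  if not isinstance(module, BN_LAYERS):
    return
  if not update_batch_norm:
    if not hasattr(module, 'momentum_backup'):
      module.momentum_backup = module.momentum
    # momentum may be stored as a float or a tensor; zero it either way.
    if torch.is_tensor(module.momentum_backup):
      module.momentum = torch.zeros_like(module.momentum_backup)
    else:
      module.momentum = 0.0
  elif hasattr(module, 'momentum_backup'):
    module.momentum = module.momentum_backup


# Example: freeze running statistics for one forward pass, then restore them.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model.apply(lambda m: update_batch_norm_fn(m, update_batch_norm=False))
_ = model(torch.randn(2, 3, 16, 16))  # batch stats normalize, running stats frozen
model.apply(lambda m: update_batch_norm_fn(m, update_batch_norm=True))
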