diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py
index 2e5828912..590f500fa 100644
--- a/algorithmic_efficiency/pytorch_utils.py
+++ b/algorithmic_efficiency/pytorch_utils.py
@@ -57,6 +57,7 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
   dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX)
   return time_tensor.item()
 
+
 def update_batch_norm_fn(module: spec.ParameterContainer,
                          update_batch_norm: bool) -> None:
   bn_layers = (
@@ -75,4 +76,4 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
     else:
       module.momentum = 0.0
   elif hasattr(module, 'momentum_backup'):
-    module.momentum = module.momentum_backup
\ No newline at end of file
+    module.momentum = module.momentum_backup
diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
index 09338ca82..059352fb6 100644
--- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
+++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
@@ -31,10 +31,10 @@ def __call__(self,
                update_batch_norm: bool = True,
                use_running_average_bn: bool = None) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
-
-    # Preserve default behavior for backwards compatibility
+
+    # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
-      use_running_average_bn = not update_batch_norm
+      use_running_average_bn = not update_batch_norm
     norm = functools.partial(
         nn.BatchNorm,
         use_running_average=use_running_average_bn,
diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
index 019dde38c..8268c6ca3 100644
--- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
+++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -111,7 +111,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
     variables = {'params': params, **model_state}
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py
index 2e680cbd9..34cd17440 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py
@@ -88,9 +88,9 @@ def __call__(self,
                use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
 
-    # Preserve default behavior for backwards compatibility
+    # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
-      use_running_average_bn = not update_batch_norm
+      use_running_average_bn = not update_batch_norm
     norm = functools.partial(
         nn.BatchNorm,
         use_running_average=use_running_average_bn,
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
index 46168c2a0..2747fc2db 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -149,7 +149,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
     variables = {'params': params, **model_state}
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py
index 077ff0f89..db92f56d4 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -454,7 +454,11 @@ def setup(self):
     self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
 
   @nn.compact
-  def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               update_batch_norm,
+               use_running_average_bn):
     rank = inputs.ndim
     reduce_over_dims = list(range(0, rank - 1))
@@ -462,7 +466,7 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag
     momentum = self.config.batch_norm_momentum
     epsilon = self.config.batch_norm_epsilon
 
-    if use_running_average_bn:
+    if use_running_average_bn:
       mean = self.ra_mean.value
       var = self.ra_var.value
 
@@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag
           keepdims=True)
       var = sum_vv / count_v
-
+
       if update_batch_norm:
         self.ra_mean.value = momentum * \
           self.ra_mean.value + (1 - momentum) * mean
         self.ra_var.value = momentum * \
           self.ra_var.value + (1 - momentum) * var
-
+
     inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
     bn_output = (inputs - mean) * inv + self.beta
     bn_output *= 1.0 - padding
@@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module):
   config: ConformerConfig
 
   @nn.compact
-  def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm,
+               use_running_average_bn):
     config = self.config
     inputs = LayerNorm(dim=config.encoder_dim)(inputs)
 
@@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running
         kernel_init=nn.initializers.xavier_uniform())(
             inputs)
 
-    inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn)
+    inputs = BatchNorm(config)(inputs,
+                               input_paddings,
+                               update_batch_norm,
+                               use_running_average_bn)
     if config.activation_function_name == 'swish':
       activation_fn = nn.swish
     elif config.activation_function_name == 'gelu':
@@ -588,7 +600,12 @@ class ConformerBlock(nn.Module):
   config: ConformerConfig
 
   @nn.compact
-  def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm,
+               use_running_average):
     config = self.config
     padding_mask = jnp.expand_dims(1 - input_paddings, -1)
 
@@ -631,12 +648,12 @@ def setup(self):
         .use_dynamic_time_mask_max_frames)
 
   @nn.compact
-  def __call__(self,
-               inputs,
-               input_paddings,
-               train,
-               update_batch_norm: Optional[bool] = None,
-               use_running_average_bn: Optional[bool] = None):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm: Optional[bool] = None,
+               use_running_average_bn: Optional[bool] = None):
     config = self.config
 
     outputs = inputs
@@ -673,7 +690,11 @@ def __call__(self,
 
     # Run the conformer encoder layers.
     for _ in range(config.num_encoder_layers):
-      outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn)
+      outputs = ConformerBlock(config)(outputs,
+                                       output_paddings,
+                                       train,
+                                       update_batch_norm,
+                                       use_running_average_bn)
 
     outputs = LayerNorm(config.encoder_dim)(outputs)
     # Run the decoder which in this case is a trivial projection layer.
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
index 6c55acfb0..e362f973b 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -108,7 +108,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
     inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
     is_train_mode = mode == spec.ForwardPassMode.TRAIN
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py
index cab73df4a..61400806a 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings):
           self.momentum) * mean.detach()
       self.running_var = (1 - self.momentum) * self.running_var + (
           self.momentum) * var.detach()
-
+
     else:
       mean = self.running_mean
       var = self.running_var
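Note (not part of the patch): a minimal sketch of how a caller might exercise the use_running_average_bn keyword threaded through model_fn in the hunks above, e.g. to normalize with the stored running statistics in an eval-style forward pass. The names workload, params, batch, model_state and rng are hypothetical placeholders; the argument order assumes the existing spec.Workload.model_fn API.

    # Hypothetical call site; only update_batch_norm / use_running_average_bn
    # relate to this diff, everything else is a placeholder.
    logits, model_state = workload.model_fn(
        params,
        batch,
        model_state,
        spec.ForwardPassMode.EVAL,
        rng,
        update_batch_norm=False,
        # New optional flag; leaving it as None preserves the old behavior
        # (use_running_average_bn = not update_batch_norm).
        use_running_average_bn=True)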