diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index 16eb8bc2d..9bb4411ec 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -60,7 +60,7 @@ class ConformerConfig: use_specaug: bool = True attention_temperature: float = 1.0 activation_function_name: str = 'swish' - use_post_layer_norm: bool = False + use_post_layer_norm: bool = True class LayerNorm(nn.Module): diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 357c9d848..517a9c4f1 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -377,7 +377,7 @@ class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): property def use_post_layer_norm(self) -> bool: - return True + return False class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 49c637bef..0787403a9 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -44,7 +44,7 @@ class ConformerConfig: use_specaug: bool = True attention_temperature: float = 1.0 activation_function_name: str = 'swish' - use_post_layer_norm: bool = False + use_post_layer_norm: bool = True def initialize(m): diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index eef5a5472..a0a6cfcb7 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -79,7 +79,7 @@ def init_model_fn( use_post_layer_norm=self.use_post_layer_norm, activation_function_name=activation_function_name)) self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - model.initialize(model) + models.initialize(model) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -343,7 +343,7 @@ class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): property def use_post_layer_norm(self) -> bool: - return True + return False class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):