From cae33e25c87c3fc7a012b19d269f48a90404c4e5 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 01:06:06 +0000 Subject: [PATCH 01/14] modify jax model to add variants --- .../librispeech_jax/models.py | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index 769e8c496..718ee4c23 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -58,7 +58,8 @@ class DeepspeechConfig: enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True - + use_tanh: bool = False + layernorm_everywhere: bool = False class Subsample(nn.Module): """Module to perform strided convolution in order to subsample inputs. @@ -80,7 +81,9 @@ def __call__(self, inputs, output_paddings, train): batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon, input_channels=1, - output_channels=config.encoder_dim)(outputs, output_paddings, train) + output_channels=config.encoder_dim, + use_tanh=config.use_tanh + )(outputs, output_paddings, train) outputs, output_paddings = Conv2dSubsampling( encoder_dim=config.encoder_dim, @@ -88,7 +91,8 @@ def __call__(self, inputs, output_paddings, train): batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon, input_channels=config.encoder_dim, - output_channels=config.encoder_dim)(outputs, output_paddings, train) + output_channels=config.encoder_dim, + use_tanh=config.use_tanh)(outputs, output_paddings, train) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape @@ -127,6 +131,7 @@ class Conv2dSubsampling(nn.Module): dtype: Any = jnp.float32 batch_norm_momentum: float = 0.999 batch_norm_epsilon: float = 0.001 + use_tanh: bool = False def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) @@ -150,7 +155,12 @@ def __call__(self, inputs, paddings, train): feature_group_count=feature_group_count) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) - outputs = nn.relu(outputs) + + if self.use_tanh: + outputs = nn.tanh(outputs) + else: + outputs = nn.relu(outputs) + # Computing correct paddings post input convolution. input_length = paddings.shape[1] @@ -182,16 +192,22 @@ def __call__(self, inputs, input_paddings=None, train=False): padding_mask = jnp.expand_dims(1 - input_paddings, -1) config = self.config - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, input_paddings, train) + if config.layernorm_everywhere: + inputs = LayerNorm(config.encoder_dim)(inputs) + else: + inputs = BatchNorm(config.encoder_dim, + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon)(inputs, input_paddings, train) inputs = nn.Dense( config.encoder_dim, use_bias=True, kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = nn.relu(inputs) + if config.use_tanh: + inputs = nn.tanh(inputs) + else: + inputs = nn.relu(inputs) inputs *= padding_mask if config.feed_forward_dropout_rate is None: @@ -416,10 +432,13 @@ class BatchRNN(nn.Module): def __call__(self, inputs, input_paddings, train): config = self.config - inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, input_paddings, train) + if config.layernorm_everywhere: + inputs = LayerNorm(config.encoder_dim)(inputs) + else: + inputs = BatchNorm(config.encoder_dim, + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon)(inputs, input_paddings, train) output = CudnnLSTM( features=config.encoder_dim // 2, bidirectional=config.bidirectional, From b17fc26a1f1c14ef001e45663dfc28ce8fa01a33 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 01:44:20 +0000 Subject: [PATCH 02/14] add jax deepspeech workload variants --- .../librispeech_jax/workload.py | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index ac6005225..aebcd50b0 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -29,7 +29,14 @@ def init_model_fn( model_config = models.DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate) + input_dropout_rate=aux_dropout_rate, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count, + ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] @@ -67,3 +74,60 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: return 55_506 # ~15.4 hours + + @property + def use_tanh(self) -> bool: + return False + + @property + def enable_residual_connections(self) -> bool: + return True + + @property + def enable_decoder_layer_norm(self) -> bool: + return True + + @property + def layernorm_everywhere(self) -> bool: + return False + + @property + def freq_mask_count(self) -> int: + return 2 + + @property + def time_mask_count(self) -> int: + return 10 + + +class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechConformerWorkload): + + @property + def use_tanh(self) -> bool: + return True + + +class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechConformerWorkload): + + @property + def enable_residual_connections(self) -> bool: + return False + + +class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechConformerWorkload): + + @property + def enable_decoder_layer_norm(self) -> bool: + return False + + @property + def layernorm_everywhere(self) -> bool: + return True + + @property + def freq_mask_count(self) -> int: + return 4 + + @property + def time_mask_count(self) -> int: + return 15 From b0421677db9ed5057c37c0c1ef6d41c5c72618a8 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 02:03:01 +0000 Subject: [PATCH 03/14] add tanh variant --- .../librispeech_jax/models.py | 2 +- .../librispeech_pytorch/models.py | 23 +++++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index 718ee4c23..db8416a5e 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -158,7 +158,7 @@ def __call__(self, inputs, paddings, train): if self.use_tanh: outputs = nn.tanh(outputs) - else: + else: outputs = nn.relu(outputs) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index db7bdd7d1..b2cc1a088 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -45,6 +45,8 @@ class DeepspeechConfig: enable_residual_connections: bool = True enable_decoder_layer_norm: bool = True bidirectional: bool = True + use_tanh: bool = False + layernorm_everywhere = False class LayerNorm(nn.Module): @@ -77,9 +79,9 @@ def __init__(self, config: DeepspeechConfig): self.encoder_dim = encoder_dim self.conv1 = Conv2dSubsampling( - input_channels=1, output_channels=encoder_dim) + input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, output_channels=encoder_dim) + input_channels=encoder_dim, output_channels=encoder_dim, use_tanh=config.use_tanh) self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) @@ -115,7 +117,8 @@ def __init__(self, filter_stride: Tuple[int] = (2, 2), padding: str = 'SAME', batch_norm_momentum: float = 0.999, - batch_norm_epsilon: float = 0.001): + batch_norm_epsilon: float = 0.001, + use_tanh: bool = False): super().__init__() self.input_channels = input_channels @@ -129,6 +132,8 @@ def __init__(self, nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) self.bias = nn.Parameter(torch.zeros(output_channels)) + self.use_tanh = use_tanh + def get_same_padding(self, input_shape): in_height, in_width = input_shape[2:] stride_height, stride_width = self.filter_stride @@ -162,7 +167,10 @@ def forward(self, inputs, paddings): dilation=(1, 1), groups=groups) - outputs = F.relu(outputs) + if self.use_tanh: + outputs = F.tanh(outputs) + else: + outputs = F.relu(outputs) input_length = paddings.shape[1] stride = self.filter_stride[0] @@ -202,7 +210,12 @@ def forward(self, inputs, input_paddings): padding_mask = (1 - input_paddings)[:, :, None] inputs = self.bn(inputs, input_paddings) inputs = self.lin(inputs) - inputs = F.relu(inputs) + + if self.config.use_tanh: + inputs = F.tanh(inputs) + else: + inputs = F.relu(inputs) + inputs = inputs * padding_mask inputs = self.dropout(inputs) From 83f079bd6a5f885c6a22dece41a9de7c4d3dd179 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 02:21:24 +0000 Subject: [PATCH 04/14] add deepspeech pytorch variants --- .../librispeech_pytorch/models.py | 28 ++++---- .../librispeech_pytorch/workload.py | 65 ++++++++++++++++++- 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index b2cc1a088..594d1c921 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -46,7 +46,7 @@ class DeepspeechConfig: enable_decoder_layer_norm: bool = True bidirectional: bool = True use_tanh: bool = False - layernorm_everywhere = False + layernorm_everywhere: bool = False class LayerNorm(nn.Module): @@ -195,10 +195,13 @@ def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config - self.bn = BatchNorm( - dim=config.encoder_dim, - batch_norm_momentum=config.batch_norm_momentum, - batch_norm_epsilon=config.batch_norm_epsilon) + if self.config.layernorm_everywhere: + self.normalization_layer = LayerNorm(config.encoder_dim) + else: + self.normalization_layer = BatchNorm( + dim=config.encoder_dim, + batch_norm_momentum=config.batch_norm_momentum, + batch_norm_epsilon=config.batch_norm_epsilon) self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True) if config.feed_forward_dropout_rate is None: feed_forward_dropout_rate = 0.1 @@ -208,7 +211,7 @@ def __init__(self, config: DeepspeechConfig): def forward(self, inputs, input_paddings): padding_mask = (1 - input_paddings)[:, :, None] - inputs = self.bn(inputs, input_paddings) + inputs = self.normalization_layer(inputs, input_paddings) inputs = self.lin(inputs) if self.config.use_tanh: @@ -278,9 +281,12 @@ def __init__(self, config: DeepspeechConfig): bidirectional = config.bidirectional self.bidirectional = bidirectional - self.bn = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) + if config.layernorm_everywhere: + self.normalization_layer = nn.LayerNorm(config.encoder_dim) + else: + self.normalization_layer = BatchNorm(config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon) if bidirectional: self.lstm = nn.LSTM( @@ -293,7 +299,7 @@ def __init__(self, config: DeepspeechConfig): input_size=input_size, hidden_size=hidden_size, batch_first=True) def forward(self, inputs, input_paddings): - inputs = self.bn(inputs, input_paddings) + inputs = self.normalization_layer(inputs, input_paddings) lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( inputs, lengths, batch_first=True, enforce_sorted=False) @@ -342,7 +348,7 @@ def __init__(self, config: DeepspeechConfig): [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) if config.enable_decoder_layer_norm: - self.ln = LayerNorm(config.encoder_dim) + self.ln = nn.LayerNorm(config.encoder_dim) else: self.ln = nn.Identity() diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index bcdd78fb5..9202c5fc0 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -37,7 +37,13 @@ def init_model_fn( DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, use_specaug=self.use_specaug, - input_dropout_rate=aux_dropout_rate)).eval() + input_dropout_rate=aux_dropout_rate, + use_tanh=self.use_tanh, + enable_residual_connections=self.enable_residual_connections, + enable_decoder_layer_norm=self.enable_decoder_layer_norm, + layernorm_everywhere=self.layernorm_everywhere, + freq_mask_count=self.freq_mask_count, + time_mask_count=self.time_mask_count)).eval() self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') # Run model once to initialize lazy layers. t = MAX_INPUT_LENGTH @@ -76,3 +82,60 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: return 55_506 # ~15.4 hours + + @property + def use_tanh(self) -> bool: + return False + + @property + def enable_residual_connections(self) -> bool: + return True + + @property + def enable_decoder_layer_norm(self) -> bool: + return True + + @property + def layernorm_everywhere(self) -> bool: + return False + + @property + def freq_mask_count(self) -> int: + return 2 + + @property + def time_mask_count(self) -> int: + return 10 + + +class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechConformerWorkload): + + @property + def use_tanh(self) -> bool: + return True + + +class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechConformerWorkload): + + @property + def enable_residual_connections(self) -> bool: + return False + + +class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechConformerWorkload): + + @property + def enable_decoder_layer_norm(self) -> bool: + return False + + @property + def layernorm_everywhere(self) -> bool: + return True + + @property + def freq_mask_count(self) -> int: + return 4 + + @property + def time_mask_count(self) -> int: + return 15 \ No newline at end of file From 7e2ca31ef61f9e5649e9edad14539b473ecfb049 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 02:38:33 +0000 Subject: [PATCH 05/14] fix deepspeech batchnorm layer --- .../librispeech_pytorch/models.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 594d1c921..3c964e5cf 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -195,7 +195,7 @@ def __init__(self, config: DeepspeechConfig): super().__init__() self.config = config - if self.config.layernorm_everywhere: + if config.layernorm_everywhere: self.normalization_layer = LayerNorm(config.encoder_dim) else: self.normalization_layer = BatchNorm( @@ -211,7 +211,11 @@ def __init__(self, config: DeepspeechConfig): def forward(self, inputs, input_paddings): padding_mask = (1 - input_paddings)[:, :, None] - inputs = self.normalization_layer(inputs, input_paddings) + if self.config.layernorm_everywhere: + inputs = self.normalization_layer(inputs) + else: # batchnorm + inputs = self.normalization_layer(inputs, input_paddings) + inputs = self.lin(inputs) if self.config.use_tanh: @@ -299,7 +303,10 @@ def __init__(self, config: DeepspeechConfig): input_size=input_size, hidden_size=hidden_size, batch_first=True) def forward(self, inputs, input_paddings): - inputs = self.normalization_layer(inputs, input_paddings) + if self.config.layernorm_everywhere: + inputs = self.normalization_layer(inputs) + else: + inputs = self.normalization_layer(inputs, input_paddings) lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( inputs, lengths, batch_first=True, enforce_sorted=False) From 50ff71f2bf79495225be84c8fcef92e413ed2e9e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 20 Dec 2023 23:18:01 +0000 Subject: [PATCH 06/14] add deepspeech workoad variant names to docker script --- algorithmic_efficiency/workloads/workloads.py | 12 ++++++++++++ docker/scripts/startup.sh | 3 ++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6d0b08cef..875504bc6 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -90,6 +90,18 @@ 'workload_path': 'librispeech_deepspeech/librispeech', 'workload_class_name': 'LibriSpeechDeepSpeechWorkload', }, + 'librispeech_deepspeech_tanh': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload', + }, + 'librispeech_deepspeech_no_resnet': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload', + }, + 'librispeech_deepspeech_norm_and_spec_aug': { + 'workload_path': 'librispeech_deepspeech/librispeech', + 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload', + }, 'mnist': { 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload' }, diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index be14ab498..3596a2990 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -119,7 +119,8 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ - "fastmri_layernorm") + "fastmri_layernorm" "librispeech_deepspeech_tanh" \ + "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug") # Set data and experiment paths ROOT_DATA_BUCKET="gs://mlcommons-data" From 34996e9386572df77adffbe2b1912536145a7a5e Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 04:18:28 -0500 Subject: [PATCH 07/14] minor --- .../workloads/librispeech_conformer/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index c2413c076..9ab3ac488 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -67,7 +67,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 256 + return 128 @property def train_mean(self): From 72be29679cb824390a952c27c6fa46e7305132ef Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 04:21:33 -0500 Subject: [PATCH 08/14] Revert batch size --- .../workloads/librispeech_conformer/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index 9ab3ac488..c2413c076 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -67,7 +67,7 @@ def num_test_examples(self) -> int: @property def eval_batch_size(self) -> int: - return 128 + return 256 @property def train_mean(self): From daef103dbb65c0659ddfe1c53afac40ad2d4b6cb Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 04:25:56 -0500 Subject: [PATCH 09/14] minor --- .../librispeech_pytorch/workload.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 18853d19c..c9208a793 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -334,6 +334,9 @@ def _eval_model_on_split(self, class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def attention_temperature(self) -> float: return 1.6 @@ -341,6 +344,9 @@ def attention_temperature(self) -> float: class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def use_post_layer_norm(self) -> bool: return False @@ -348,6 +354,9 @@ def use_post_layer_norm(self) -> bool: class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def use_gelu(self) -> bool: return True From 3a7876565aa525da34e4e10ae23dcfc4bd5e3e76 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 04:30:23 -0500 Subject: [PATCH 10/14] minor --- .../librispeech_pytorch/workload.py | 9 --------- .../librispeech_deepspeech/librispeech_jax/workload.py | 9 +++++++++ .../librispeech_pytorch/workload.py | 9 +++++++++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index c9208a793..18853d19c 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -334,9 +334,6 @@ def _eval_model_on_split(self, class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): - def eval_batch_size(self) -> int: - return 128 - @property def attention_temperature(self) -> float: return 1.6 @@ -344,9 +341,6 @@ def attention_temperature(self) -> float: class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): - def eval_batch_size(self) -> int: - return 128 - @property def use_post_layer_norm(self) -> bool: return False @@ -354,9 +348,6 @@ def use_post_layer_norm(self) -> bool: class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): - def eval_batch_size(self) -> int: - return 128 - @property def use_gelu(self) -> bool: return True diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index aebcd50b0..c1e2e71b4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -102,6 +102,9 @@ def time_mask_count(self) -> int: class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def use_tanh(self) -> bool: return True @@ -109,6 +112,9 @@ def use_tanh(self) -> bool: class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def enable_residual_connections(self) -> bool: return False @@ -116,6 +122,9 @@ def enable_residual_connections(self) -> bool: class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def enable_decoder_layer_norm(self) -> bool: return False diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 9202c5fc0..932d38d62 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -110,6 +110,9 @@ def time_mask_count(self) -> int: class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def use_tanh(self) -> bool: return True @@ -117,6 +120,9 @@ def use_tanh(self) -> bool: class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def enable_residual_connections(self) -> bool: return False @@ -124,6 +130,9 @@ def enable_residual_connections(self) -> bool: class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechConformerWorkload): + def eval_batch_size(self) -> int: + return 128 + @property def enable_decoder_layer_norm(self) -> bool: return False From d718663e03682e0ff6f90521460d055c0c415fbc Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 04:35:05 -0500 Subject: [PATCH 11/14] minor --- .../librispeech_pytorch/workload.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 932d38d62..d3eaf2e73 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -108,30 +108,21 @@ def time_mask_count(self) -> int: return 10 -class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechConformerWorkload): - - def eval_batch_size(self) -> int: - return 128 +class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): @property def use_tanh(self) -> bool: return True -class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechConformerWorkload): - - def eval_batch_size(self) -> int: - return 128 +class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): @property def enable_residual_connections(self) -> bool: return False -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechConformerWorkload): - - def eval_batch_size(self) -> int: - return 128 +class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload): @property def enable_decoder_layer_norm(self) -> bool: From 4bc8fece8786b31fba12bbc87998c2b10c93ef8e Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 04:48:32 -0500 Subject: [PATCH 12/14] minor --- .../librispeech_jax/workload.py | 15 +++------------ .../librispeech_pytorch/workload.py | 4 ++++ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index c1e2e71b4..036214901 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,30 +100,21 @@ def time_mask_count(self) -> int: return 10 -class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechConformerWorkload): - - def eval_batch_size(self) -> int: - return 128 +class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): @property def use_tanh(self) -> bool: return True -class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechConformerWorkload): - - def eval_batch_size(self) -> int: - return 128 +class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): @property def enable_residual_connections(self) -> bool: return False -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechConformerWorkload): - - def eval_batch_size(self) -> int: - return 128 +class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload): @property def enable_decoder_layer_norm(self) -> bool: diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index d3eaf2e73..5f20aa685 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -124,6 +124,10 @@ def enable_residual_connections(self) -> bool: class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload): + @property + def eval_batch_size(self) -> int: + return 128 + @property def enable_decoder_layer_norm(self) -> bool: return False From 572cebf65220745bea8f6deef02ac1fc365fe331 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Tue, 6 Feb 2024 17:10:12 -0500 Subject: [PATCH 13/14] Lint fix --- .../librispeech_jax/models.py | 18 +++++++++++------- .../librispeech_jax/workload.py | 7 ++++--- .../librispeech_pytorch/models.py | 16 +++++++++------- .../librispeech_pytorch/workload.py | 7 ++++--- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index db8416a5e..f9eb732e9 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -61,6 +61,7 @@ class DeepspeechConfig: use_tanh: bool = False layernorm_everywhere: bool = False + class Subsample(nn.Module): """Module to perform strided convolution in order to subsample inputs. @@ -161,7 +162,6 @@ def __call__(self, inputs, paddings, train): else: outputs = nn.relu(outputs) - # Computing correct paddings post input convolution. input_length = paddings.shape[1] stride = self.filter_stride[0] @@ -196,9 +196,11 @@ def __call__(self, inputs, input_paddings=None, train=False): inputs = LayerNorm(config.encoder_dim)(inputs) else: inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, input_paddings, train) + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon)(inputs, + input_paddings, + train) inputs = nn.Dense( config.encoder_dim, use_bias=True, @@ -436,9 +438,11 @@ def __call__(self, inputs, input_paddings, train): inputs = LayerNorm(config.encoder_dim)(inputs) else: inputs = BatchNorm(config.encoder_dim, - config.dtype, - config.batch_norm_momentum, - config.batch_norm_epsilon)(inputs, input_paddings, train) + config.dtype, + config.batch_norm_momentum, + config.batch_norm_epsilon)(inputs, + input_paddings, + train) output = CudnnLSTM( features=config.encoder_dim // 2, bidirectional=config.bidirectional, diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2867f94f2..b578d4598 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -36,7 +36,7 @@ def init_model_fn( layernorm_everywhere=self.layernorm_everywhere, freq_mask_count=self.freq_mask_count, time_mask_count=self.time_mask_count, - ) + ) self._model = models.Deepspeech(model_config) input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] @@ -83,7 +83,7 @@ def use_tanh(self) -> bool: def enable_residual_connections(self) -> bool: return True - @property + @property def enable_decoder_layer_norm(self) -> bool: return True @@ -114,7 +114,8 @@ def enable_residual_connections(self) -> bool: return False -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload): +class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload + ): @property def eval_batch_size(self) -> int: diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index 3c964e5cf..d270df236 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -81,7 +81,9 @@ def __init__(self, config: DeepspeechConfig): self.conv1 = Conv2dSubsampling( input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh) self.conv2 = Conv2dSubsampling( - input_channels=encoder_dim, output_channels=encoder_dim, use_tanh=config.use_tanh) + input_channels=encoder_dim, + output_channels=encoder_dim, + use_tanh=config.use_tanh) self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True) @@ -213,16 +215,16 @@ def forward(self, inputs, input_paddings): padding_mask = (1 - input_paddings)[:, :, None] if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) - else: # batchnorm + else: # batchnorm inputs = self.normalization_layer(inputs, input_paddings) - + inputs = self.lin(inputs) - + if self.config.use_tanh: inputs = F.tanh(inputs) else: inputs = F.relu(inputs) - + inputs = inputs * padding_mask inputs = self.dropout(inputs) @@ -289,8 +291,8 @@ def __init__(self, config: DeepspeechConfig): self.normalization_layer = nn.LayerNorm(config.encoder_dim) else: self.normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) + config.batch_norm_momentum, + config.batch_norm_epsilon) if bidirectional: self.lstm = nn.LSTM( diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 5f20aa685..23d533aa1 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -91,7 +91,7 @@ def use_tanh(self) -> bool: def enable_residual_connections(self) -> bool: return True - @property + @property def enable_decoder_layer_norm(self) -> bool: return True @@ -122,7 +122,8 @@ def enable_residual_connections(self) -> bool: return False -class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload): +class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload + ): @property def eval_batch_size(self) -> int: @@ -142,4 +143,4 @@ def freq_mask_count(self) -> int: @property def time_mask_count(self) -> int: - return 15 \ No newline at end of file + return 15 From f660a215a78a69bdad47fd20bca44a08decd6936 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Wed, 14 Feb 2024 00:48:15 +0000 Subject: [PATCH 14/14] deepspeech modeldiffs --- .../librispeech_pytorch/models.py | 16 +++--- .../__init__.py | 0 .../compare.py | 53 +++++++++++++++++++ .../__init__.py | 0 .../librispeech_deepspeech_normaug/compare.py | 53 +++++++++++++++++++ .../librispeech_deepspeech_tanh/__init__.py | 0 .../librispeech_deepspeech_tanh/compare.py | 53 +++++++++++++++++++ 7 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 tests/modeldiffs/librispeech_deepspeech_noresnet/__init__.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_normaug/__init__.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_normaug/compare.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_tanh/__init__.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_tanh/compare.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index d270df236..a5ee3fa0a 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -200,7 +200,7 @@ def __init__(self, config: DeepspeechConfig): if config.layernorm_everywhere: self.normalization_layer = LayerNorm(config.encoder_dim) else: - self.normalization_layer = BatchNorm( + self.bn_normalization_layer = BatchNorm( dim=config.encoder_dim, batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon) @@ -216,7 +216,7 @@ def forward(self, inputs, input_paddings): if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) else: # batchnorm - inputs = self.normalization_layer(inputs, input_paddings) + inputs = self.bn_normalization_layer(inputs, input_paddings) inputs = self.lin(inputs) @@ -288,11 +288,11 @@ def __init__(self, config: DeepspeechConfig): self.bidirectional = bidirectional if config.layernorm_everywhere: - self.normalization_layer = nn.LayerNorm(config.encoder_dim) + self.normalization_layer = LayerNorm(config.encoder_dim) else: - self.normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) + self.bn_normalization_layer = BatchNorm(config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon) if bidirectional: self.lstm = nn.LSTM( @@ -308,7 +308,7 @@ def forward(self, inputs, input_paddings): if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) else: - inputs = self.normalization_layer(inputs, input_paddings) + inputs = self.bn_normalization_layer(inputs, input_paddings) lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( inputs, lengths, batch_first=True, enforce_sorted=False) @@ -357,7 +357,7 @@ def __init__(self, config: DeepspeechConfig): [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) if config.enable_decoder_layer_norm: - self.ln = nn.LayerNorm(config.encoder_dim) + self.ln = LayerNorm(config.encoder_dim) else: self.ln = nn.Identity() diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/__init__.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py new file mode 100644 index 000000000..6c00bdf69 --- /dev/null +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -0,0 +1,53 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ + LibriSpeechDeepSpeechTanhWorkload as JaxWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ + LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.librispeech_deepspeech.compare import key_transform +from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + # Test outputs for identical weights and inputs. + wave = torch.randn(2, 320000) + pad = torch.zeros_like(wave) + pad[0, 200000:] = 1 + + jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} + pyt_batch = {'inputs': (wave, pad)} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] * + (1 - out_outpad[1][:, :, None])) diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/__init__.py b/tests/modeldiffs/librispeech_deepspeech_normaug/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py new file mode 100644 index 000000000..c68d6adf9 --- /dev/null +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -0,0 +1,53 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ + LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ + LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.librispeech_deepspeech.compare import key_transform +from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + # Test outputs for identical weights and inputs. + wave = torch.randn(2, 320000) + pad = torch.zeros_like(wave) + pad[0, 200000:] = 1 + + jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} + pyt_batch = {'inputs': (wave, pad)} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] * + (1 - out_outpad[1][:, :, None])) diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/__init__.py b/tests/modeldiffs/librispeech_deepspeech_tanh/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py new file mode 100644 index 000000000..4cfdf4f21 --- /dev/null +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -0,0 +1,53 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ + LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ + LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.librispeech_deepspeech.compare import key_transform +from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + # Test outputs for identical weights and inputs. + wave = torch.randn(2, 320000) + pad = torch.zeros_like(wave) + pad[0, 200000:] = 1 + + jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} + pyt_batch = {'inputs': (wave, pad)} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] * + (1 - out_outpad[1][:, :, None]))