diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index 9bb4411ec..ed05f4335 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -12,6 +12,7 @@ y = layer_norm(x) """ +import functools import math from typing import Any, List, Optional @@ -20,7 +21,6 @@ import jax import jax.numpy as jnp import numpy as np -import functools from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor @@ -213,10 +213,9 @@ def __call__(self, inputs, padding_mask=None, train=False): elif config.activation_function_name == 'gelu': activation_fn = nn.gelu else: - raise ValueError( - 'Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}') inputs = activation_fn(inputs) inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( inputs, deterministic=not train) @@ -373,8 +372,9 @@ def dot_product_attention(query, precision) # return weighted sum over values for each query position - return jnp.einsum('...hqk,...khd->...qhd', - attn_weights, value, precision=precision) * temperature + return jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, + precision=precision) * temperature class MultiHeadedSelfAttention(nn.Module): @@ -552,10 +552,9 @@ def __call__(self, inputs, input_paddings, train): elif config.activation_function_name == 'gelu': activation_fn = nn.gelu else: - raise ValueError( - 'Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{config.activation_function_name}') + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{config.activation_function_name}') inputs = activation_fn(inputs) inputs = nn.Dense( config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index c6d3cda7f..fe3a1e179 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -217,10 +217,9 @@ def forward(self, inputs, padding_mask): # Use tanh approximation of GELU which is default for jax activation_fn = partial(F.gelu, approximate='tanh') else: - raise ValueError( - 'Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}') inputs = activation_fn(inputs) inputs = self.dropout1(inputs) inputs = inputs * padding_mask @@ -426,10 +425,9 @@ def forward(self, inputs, input_paddings): elif self.config.activation_function_name == 'gelu': activation_fn = F.gelu else: - raise ValueError( - 'Only "swish" and "gelu" are supported ' - 'config.activation_function_name values, recieved ' - f'{self.config.activation_function_name}') + raise ValueError('Only "swish" and "gelu" are supported ' + 'config.activation_function_name values, recieved ' + f'{self.config.activation_function_name}') inputs = activation_fn(inputs) inputs = self.lin3(inputs) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index 8e0fb916f..7a7d5498c 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -12,7 +12,7 @@ class BaseLibrispeechWorkload(spec.Workload): def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'wer' - + @property def use_post_layer_norm(self) -> bool: raise NotImplemented diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index a972700d3..679cc75d7 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -49,8 +49,10 @@ 'workload_class_name': 'LibriSpeechConformerWorkload', }, 'librispeech_conformer_attention_temperature': { - 'workload_path': 'librispeech_conformer/librispeech', - 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload', + 'workload_path': + 'librispeech_conformer/librispeech', + 'workload_class_name': + 'LibriSpeechConformerAttentionTemperatureWorkload', }, 'librispeech_conformer_layernorm': { 'workload_path': 'librispeech_conformer/librispeech',