formatting
priyakasimbeg committed Nov 28, 2023
1 parent 722f1c0 commit 255a64a
Showing 4 changed files with 21 additions and 22 deletions.
algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py

@@ -12,6 +12,7 @@
   y = layer_norm(x)
 """
 
+import functools
 import math
 from typing import Any, List, Optional
 
@@ -20,7 +21,6 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-import functools
 
 from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \
     librispeech_preprocessor as preprocessor
@@ -213,10 +213,9 @@ def __call__(self, inputs, padding_mask=None, train=False):
     elif config.activation_function_name == 'gelu':
       activation_fn = nn.gelu
     else:
-      raise ValueError(
-          'Only "swish" and "gelu" are supported '
-          'config.activation_function_name values, recieved '
-          f'{config.activation_function_name}')
+      raise ValueError('Only "swish" and "gelu" are supported '
+                       'config.activation_function_name values, recieved '
+                       f'{config.activation_function_name}')
     inputs = activation_fn(inputs)
     inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
         inputs, deterministic=not train)
@@ -373,8 +372,9 @@ def dot_product_attention(query,
                           precision)
 
   # return weighted sum over values for each query position
-  return jnp.einsum('...hqk,...khd->...qhd',
-                    attn_weights, value, precision=precision) * temperature
+  return jnp.einsum(
+      '...hqk,...khd->...qhd', attn_weights, value,
+      precision=precision) * temperature
 
 
 class MultiHeadedSelfAttention(nn.Module):
@@ -552,10 +552,9 @@ def __call__(self, inputs, input_paddings, train):
     elif config.activation_function_name == 'gelu':
      activation_fn = nn.gelu
     else:
-      raise ValueError(
-          'Only "swish" and "gelu" are supported '
-          'config.activation_function_name values, recieved '
-          f'{config.activation_function_name}')
+      raise ValueError('Only "swish" and "gelu" are supported '
+                       'config.activation_function_name values, recieved '
+                       f'{config.activation_function_name}')
     inputs = activation_fn(inputs)
     inputs = nn.Dense(
         config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
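
For readers skimming the diff, here is a minimal JAX sketch of what the reformatted einsum in dot_product_attention computes. The shape conventions are inferred from the einsum subscripts, and the function name is illustrative rather than the repo's:

    import jax.numpy as jnp

    def attention_readout(attn_weights, value, temperature=1.0):
      # attn_weights: [..., heads, q_len, k_len]; value: [..., k_len, heads, head_dim].
      # Weighted sum of values over key positions, rescaled by a scalar
      # temperature (presumably the knob behind the new
      # LibriSpeechConformerAttentionTemperatureWorkload registered below).
      return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value) * temperature

With temperature=1.0 this reduces to ordinary attention output aggregation.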
algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py

@@ -217,10 +217,9 @@ def forward(self, inputs, padding_mask):
       # Use tanh approximation of GELU which is default for jax
       activation_fn = partial(F.gelu, approximate='tanh')
     else:
-      raise ValueError(
-          'Only "swish" and "gelu" are supported '
-          'config.activation_function_name values, recieved '
-          f'{self.config.activation_function_name}')
+      raise ValueError('Only "swish" and "gelu" are supported '
+                       'config.activation_function_name values, recieved '
+                       f'{self.config.activation_function_name}')
     inputs = activation_fn(inputs)
     inputs = self.dropout1(inputs)
     inputs = inputs * padding_mask
@@ -426,10 +425,9 @@ def forward(self, inputs, input_paddings):
     elif self.config.activation_function_name == 'gelu':
       activation_fn = F.gelu
     else:
-      raise ValueError(
-          'Only "swish" and "gelu" are supported '
-          'config.activation_function_name values, recieved '
-          f'{self.config.activation_function_name}')
+      raise ValueError('Only "swish" and "gelu" are supported '
+                       'config.activation_function_name values, recieved '
+                       f'{self.config.activation_function_name}')
     inputs = activation_fn(inputs)
     inputs = self.lin3(inputs)
 
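
The comment in the first hunk ("tanh approximation of GELU which is default for jax") is easy to verify. A small sketch using jax.nn.gelu, whose approximate flag defaults to True, mirroring partial(F.gelu, approximate='tanh') on the PyTorch side:

    import jax.numpy as jnp
    from jax import nn

    x = jnp.linspace(-3.0, 3.0, 7)
    approx = nn.gelu(x)                       # approximate=True is the JAX default
    exact = nn.gelu(x, approximate=False)     # erf-based exact GELU
    print(jnp.max(jnp.abs(approx - exact)))   # small but nonzero difference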
algorithmic_efficiency/workloads/librispeech_conformer/workload.py

@@ -12,7 +12,7 @@ class BaseLibrispeechWorkload(spec.Workload):
   def target_metric_name(self) -> str:
     """The name of the target metric (useful for scoring/processing code)."""
     return 'wer'
- 
+
   @property
   def use_post_layer_norm(self) -> bool:
     raise NotImplemented
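
One unchanged context line above deserves a note: raise NotImplemented raises a TypeError at call time, because NotImplemented is a constant rather than an exception class. A minimal sketch of the conventional spelling (the class name is a hypothetical stand-in, not a proposed change to this commit):

    class BaseWorkloadSketch:  # hypothetical stand-in for BaseLibrispeechWorkload

      @property
      def use_post_layer_norm(self) -> bool:
        raise NotImplementedError  # an exception subclass, unlike NotImplemented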
algorithmic_efficiency/workloads/workloads.py (6 changes: 4 additions & 2 deletions)

@@ -49,8 +49,10 @@
         'workload_class_name': 'LibriSpeechConformerWorkload',
     },
     'librispeech_conformer_attention_temperature': {
-        'workload_path': 'librispeech_conformer/librispeech',
-        'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload',
+        'workload_path':
+            'librispeech_conformer/librispeech',
+        'workload_class_name':
+            'LibriSpeechConformerAttentionTemperatureWorkload',
     },
     'librispeech_conformer_layernorm': {
         'workload_path': 'librispeech_conformer/librispeech',
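
The reformatted entry above is plain data. A hedged sketch of how a registry like this is typically consumed via dynamic import; the helper name, base package string, and '_jax' suffix convention are assumptions for illustration, not the repo's actual API:

    import importlib

    def get_workload_class(registry, name, framework='jax'):
      # Look up the registry entry, derive a dotted module path, import the class.
      entry = registry[name]
      dotted = entry['workload_path'].replace('/', '.')
      module = importlib.import_module(
          f'algorithmic_efficiency.workloads.{dotted}_{framework}.workload')
      return getattr(module, entry['workload_class_name'])

Under these assumptions, get_workload_class(WORKLOADS, 'librispeech_conformer_attention_temperature') would resolve LibriSpeechConformerAttentionTemperatureWorkload from the librispeech_jax workload module.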
