Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conformer workload variants #590

Merged
merged 35 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c776d26
add variants
priyakasimbeg Nov 13, 2023
a0113e6
conformer variants
priyakasimbeg Nov 14, 2023
4738fcd
add attention temp in fwd of MHSAwithQS fwd
priyakasimbeg Nov 16, 2023
97b9818
Merge branch 'dev' into conformer_workload_variants
priyakasimbeg Nov 16, 2023
d7b8cc6
add conformer variants
priyakasimbeg Nov 16, 2023
7444828
add bsz for conformer variants to baselines
priyakasimbeg Nov 16, 2023
07382c0
add cnoformer variants to docker startup.sh"
priyakasimbeg Nov 16, 2023
c9c195f
variants
priyakasimbeg Nov 16, 2023
c2e9510
add functools to librispeech jax
priyakasimbeg Nov 16, 2023
51f5a5b
fix in conformer workload
priyakasimbeg Nov 16, 2023
a5d1820
fix
priyakasimbeg Nov 16, 2023
86bf643
general variant infra
priyakasimbeg Nov 16, 2023
88bfde9
conformer variant fix
priyakasimbeg Nov 16, 2023
a440604
fix
priyakasimbeg Nov 16, 2023
d44997d
merge
priyakasimbeg Nov 22, 2023
e90bf17
merge
priyakasimbeg Nov 22, 2023
a77c04d
torch compile
priyakasimbeg Nov 22, 2023
8999335
torch compile
priyakasimbeg Nov 23, 2023
4af702e
torch compile
priyakasimbeg Nov 23, 2023
a081191
conformer workload variants
priyakasimbeg Nov 23, 2023
171eece
conformer workload variants
priyakasimbeg Nov 23, 2023
856c9f3
refactor baseworkload logic
priyakasimbeg Nov 23, 2023
b53ec9a
Merge branch 'dev' into conformer_workload_variants
priyakasimbeg Nov 27, 2023
d727522
Merge branch 'conformer_workload_variants' of github.com:mlcommons/al…
priyakasimbeg Nov 27, 2023
e61c1d9
add conformer modeldiff tests
priyakasimbeg Nov 28, 2023
91039e0
use tanh approximation for gelu
priyakasimbeg Nov 28, 2023
4d6746a
refactor gelu workload
priyakasimbeg Nov 28, 2023
722f1c0
style change for gelu workload naming
priyakasimbeg Nov 28, 2023
255a64a
formatting
priyakasimbeg Nov 28, 2023
62527d4
raise not implemented error
priyakasimbeg Nov 28, 2023
25b3fbe
push
priyakasimbeg Nov 28, 2023
6885e69
fix arg in submission_runner
priyakasimbeg Nov 28, 2023
bb435a4
fix
priyakasimbeg Dec 5, 2023
b42c91a
Merge branch 'dev' into conformer_workload_variants
priyakasimbeg Dec 7, 2023
c3d79ca
remove training white space
priyakasimbeg Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
y = layer_norm(x)
"""

import functools
import math
from typing import Any, List, Optional

Expand Down Expand Up @@ -57,6 +58,9 @@ class ConformerConfig:
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
attention_temperature: float = 1.0
activation_function_name: str = 'swish'
use_post_layer_norm: bool = True


class LayerNorm(nn.Module):
Expand Down Expand Up @@ -204,7 +208,15 @@ def __call__(self, inputs, padding_mask=None, train=False):
use_bias=True,
kernel_init=nn.initializers.xavier_uniform())(
inputs)
inputs = nn.swish(inputs)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
activation_fn = nn.gelu
else:
raise ValueError('Only "swish" and "gelu" are supported '
'config.activation_function_name values, recieved '
f'{config.activation_function_name}')
inputs = activation_fn(inputs)
inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
inputs, deterministic=not train)

Expand Down Expand Up @@ -300,7 +312,8 @@ def dot_product_attention(query,
dropout_rate=0.,
deterministic=False,
dtype=jnp.float32,
precision=None):
precision=None,
temperature=1.0):
"""Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on
Expand Down Expand Up @@ -360,7 +373,8 @@ def dot_product_attention(query,

# return weighted sum over values for each query position
return jnp.einsum(
'...hqk,...khd->...qhd', attn_weights, value, precision=precision)
'...hqk,...khd->...qhd', attn_weights, value,
precision=precision) * temperature


class MultiHeadedSelfAttention(nn.Module):
Expand All @@ -383,6 +397,8 @@ def __call__(self, inputs, paddings, train):

inputs = LayerNorm(dim=config.encoder_dim)(inputs)

attention_fn = functools.partial(
dot_product_attention, temperature=config.attention_temperature)
result = nn.SelfAttention(
num_heads=config.num_attention_heads,
qkv_features=config.encoder_dim,
Expand All @@ -392,7 +408,7 @@ def __call__(self, inputs, paddings, train):
bias_init=nn.initializers.zeros,
use_bias=True,
broadcast_dropout=False,
attention_fn=dot_product_attention,
attention_fn=attention_fn,
dropout_rate=config.attention_dropout_rate,
deterministic=not train)(inputs, attention_mask)

Expand Down Expand Up @@ -531,8 +547,15 @@ def __call__(self, inputs, input_paddings, train):
inputs)

inputs = BatchNorm(config)(inputs, input_paddings, train)

inputs = nn.swish(inputs)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
activation_fn = nn.gelu
else:
raise ValueError('Only "swish" and "gelu" are supported '
'config.activation_function_name values, recieved '
f'{config.activation_function_name}')
inputs = activation_fn(inputs)
inputs = nn.Dense(
config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
inputs)
Expand Down Expand Up @@ -579,7 +602,8 @@ def __call__(self, inputs, input_paddings, train):
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
inputs, padding_mask, train)

inputs = LayerNorm(dim=config.encoder_dim)(inputs)
if config.use_post_layer_norm:
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

return inputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def __init__(self,
self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path)
self.use_specaug = use_specaug

@property
def use_post_layer_norm(self) -> bool:
return True

@property
def use_gelu(self) -> bool:
return False

@property
def attention_temperature(self) -> float:
return 1.0

def init_model_fn(
self,
rng: spec.RandomState,
Expand All @@ -41,11 +53,18 @@ def init_model_fn(
Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as
input_dropout_rate.
"""
if self.use_gelu:
activation_function_name = 'gelu'
else:
activation_function_name = 'swish'
model_config = models.ConformerConfig(
attention_residual_dropout_rate=dropout_rate,
feed_forward_residual_dropout_rate=dropout_rate,
input_dropout_rate=aux_dropout_rate,
use_specaug=self.use_specaug)
use_specaug=self.use_specaug,
attention_temperature=self.attention_temperature,
use_post_layer_norm=self.use_post_layer_norm,
activation_function_name=activation_function_name)
self._model = models.Conformer(model_config)
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
Expand Down Expand Up @@ -344,3 +363,25 @@ def sync_batch_stats(
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
return new_model_state


class LibriSpeechConformerAttentionTemperatureWorkload(
LibriSpeechConformerWorkload):

@property
def attention_temperature(self) -> float:
return 1.6


class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload):

@property
def use_post_layer_norm(self) -> bool:
return False


class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):

@property
def use_gelu(self) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from dataclasses import dataclass
from functools import partial
import math
from typing import Tuple

Expand Down Expand Up @@ -42,6 +43,9 @@ class ConformerConfig:
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
attention_temperature: float = 1.0
activation_function_name: str = 'swish'
use_post_layer_norm: bool = True


def initialize(m):
Expand Down Expand Up @@ -207,7 +211,16 @@ def __init__(self, config: ConformerConfig):
def forward(self, inputs, padding_mask):
inputs = self.ln(inputs)
inputs = self.linear1(inputs)
inputs = F.silu(inputs)
if self.config.activation_function_name == 'swish':
activation_fn = F.silu
elif self.config.activation_function_name == 'gelu':
# Use tanh approximation of GELU which is default for jax
activation_fn = partial(F.gelu, approximate='tanh')
else:
raise ValueError('Only "swish" and "gelu" are supported '
'config.activation_function_name values, recieved '
f'{self.config.activation_function_name}')
inputs = activation_fn(inputs)
inputs = self.dropout1(inputs)
inputs = inputs * padding_mask
inputs = self.linear2(inputs)
Expand Down Expand Up @@ -270,6 +283,7 @@ def __init__(self, config: ConformerConfig):
self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)
self.attention_temperature = config.attention_temperature

def forward(self, inputs, key_padding_mask=None):
batch_size, seq_len, embed_dim = inputs.shape
Expand All @@ -284,6 +298,7 @@ def forward(self, inputs, key_padding_mask=None):
attn_mask=~key_padding_mask[:, None, None],
dropout_p=self.dropout,
).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
out = out * self.attention_temperature
out = self.out_proj(out)
return out

Expand Down Expand Up @@ -405,7 +420,15 @@ def forward(self, inputs, input_paddings):
inputs = inputs.permute(0, 2, 1)

inputs = self.bn(inputs, input_paddings)
inputs = F.silu(inputs)
if self.config.activation_function_name == 'swish':
activation_fn = F.silu
elif self.config.activation_function_name == 'gelu':
activation_fn = F.gelu
else:
raise ValueError('Only "swish" and "gelu" are supported '
'config.activation_function_name values, recieved '
f'{self.config.activation_function_name}')
inputs = activation_fn(inputs)
inputs = self.lin3(inputs)

inputs = self.dropout(inputs)
Expand All @@ -421,15 +444,18 @@ def __init__(self, config: ConformerConfig):
self.mhsa = MultiHeadedSelfAttention(config)
self.conv = ConvolutionBlock(config)
self.ff2 = FeedForwardModule(config)
self.ln = LayerNorm(dim=config.encoder_dim)
self.ln = None
if config.use_post_layer_norm:
self.ln = LayerNorm(dim=config.encoder_dim)

def forward(self, inputs, input_paddings):
padding_mask = 1 - input_paddings[:, :, None]
inputs = inputs + 0.5 * self.ff1(inputs, padding_mask)
inputs = inputs + self.mhsa(inputs, input_paddings)
inputs = inputs + self.conv(inputs, input_paddings)
inputs = inputs + 0.5 * self.ff2(inputs, padding_mask)
inputs = self.ln(inputs)
if self.ln:
inputs = self.ln(inputs)
return inputs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \
LibriSpeechDataset
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \
models as conformer_model
models

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

Expand All @@ -36,6 +36,18 @@ def __init__(self,
self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path)
self.use_specaug = use_specaug

@property
def use_post_layer_norm(self) -> bool:
return True

@property
def use_gelu(self) -> bool:
return False

@property
def attention_temperature(self) -> float:
return 1.0

def init_model_fn(
self,
rng: spec.RandomState,
Expand All @@ -52,15 +64,22 @@ def init_model_fn(
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
model = conformer_model.ConformerEncoderDecoder(
conformer_model.ConformerConfig(
if self.use_gelu:
activation_function_name = 'gelu'
else:
activation_function_name = 'swish'
model = models.ConformerEncoderDecoder(
models.ConformerConfig(
attention_residual_dropout_rate=dropout_rate,
feed_forward_residual_dropout_rate=dropout_rate,
conv_residual_dropout_rate=dropout_rate,
input_dropout_rate=aux_dropout_rate,
use_specaug=self.use_specaug))
use_specaug=self.use_specaug,
attention_temperature=self.attention_temperature,
use_post_layer_norm=self.use_post_layer_norm,
activation_function_name=activation_function_name))
self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none')
conformer_model.initialize(model)
models.initialize(model)
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
Expand Down Expand Up @@ -310,3 +329,25 @@ def _eval_model_on_split(self,
float(total_metrics['word_errors'].item() /
total_metrics['num_words'].item()),
}


class LibriSpeechConformerAttentionTemperatureWorkload(
LibriSpeechConformerWorkload):

@property
def attention_temperature(self) -> float:
return 1.6


class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload):

@property
def use_post_layer_norm(self) -> bool:
return False


class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):

@property
def use_gelu(self) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'wer'

@property
def use_post_layer_norm(self) -> bool:
raise NotImplementedError

@property
def use_gelu(self) -> bool:
raise NotImplementedError

@property
def attention_temperature(self) -> float:
raise NotImplementedError

def has_reached_validation_target(self, eval_result: Dict[str,
float]) -> bool:
return eval_result['validation/wer'] < self.validation_target_value
Expand Down
14 changes: 14 additions & 0 deletions algorithmic_efficiency/workloads/workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@
'workload_path': 'librispeech_conformer/librispeech',
'workload_class_name': 'LibriSpeechConformerWorkload',
},
'librispeech_conformer_attention_temperature': {
'workload_path':
'librispeech_conformer/librispeech',
'workload_class_name':
'LibriSpeechConformerAttentionTemperatureWorkload',
},
'librispeech_conformer_layernorm': {
'workload_path': 'librispeech_conformer/librispeech',
'workload_class_name': 'LibriSpeechConformerLayerNormWorkload',
},
'librispeech_conformer_gelu': {
'workload_path': 'librispeech_conformer/librispeech',
'workload_class_name': 'LibriSpeechConformerGeluWorkload',
},
'librispeech_deepspeech': {
'workload_path': 'librispeech_deepspeech/librispeech',
'workload_class_name': 'LibriSpeechDeepSpeechWorkload',
Expand Down
4 changes: 3 additions & 1 deletion docker/scripts/startup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \
"wmt" "mnist")
VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_vit" "fastmri" "ogbg" \
"wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \
"criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init")
"criteo1tb_resnet" "criteo1tb_layernorm" "criteo_embed_init" \
"conformer_layernorm" "conformer_attention_temperature" \
"conformer_gelu")


# Set data and experiment paths
Expand Down
Loading
Loading