Merge pull request #595 from runame/wmt-variants
Add WMT workload variants
priyakasimbeg authored Dec 12, 2023
2 parents 8de3bca + c5e9a03 commit 3ee6338
Showing 14 changed files with 730 additions and 86 deletions.
58 changes: 43 additions & 15 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/models.py
@@ -24,15 +24,19 @@ class TransformerConfig:
qkv_dim: int = 1024
mlp_dim: int = 1024
max_len: int = 256
activation: Callable = nn.relu
glu: bool = False
#If None, defaults to 0.1.
dropout_rate: Optional[float] = 0.1
#If None, defaults to 0.1.
attention_dropout_rate: Optional[float] = 0.1
attention_temp: float = 1.0
deterministic: bool = False
decode: bool = False
kernel_init: Callable = nn.initializers.xavier_uniform()
bias_init: Callable = nn.initializers.normal(stddev=1e-6)
posemb_init: Optional[Callable] = None
pre_ln: bool = True


def shift_right(x, axis=1):
@@ -155,7 +159,15 @@ def __call__(self, inputs):
kernel_init=cfg.kernel_init,
bias_init=cfg.bias_init)(
inputs)
x = nn.relu(x)
x = cfg.activation(x)
if cfg.glu:
y = nn.Dense(
cfg.mlp_dim,
dtype=cfg.dtype,
kernel_init=cfg.kernel_init,
bias_init=cfg.bias_init)(
inputs)
x = x * y
if cfg.dropout_rate is None:
dropout_rate = 0.1
else:
@@ -192,15 +204,16 @@ def __call__(self, inputs, encoder_mask=None):
output after transformer encoder block.
"""
cfg = self.config
pre_ln = cfg.pre_ln

# Attention block.
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
if cfg.attention_dropout_rate is None:
attention_dropout_rate = 0.1
else:
attention_dropout_rate = cfg.attention_dropout_rate
x = nn.SelfAttention(
x = nn.MultiHeadDotProductAttention(
num_heads=cfg.num_heads,
dtype=cfg.dtype,
qkv_features=cfg.qkv_dim,
@@ -209,20 +222,24 @@ def __call__(self, inputs, encoder_mask=None):
use_bias=False,
broadcast_dropout=False,
dropout_rate=attention_dropout_rate,
deterministic=cfg.deterministic)(x, encoder_mask)
deterministic=cfg.deterministic)(cfg.attention_temp * x,
x,
encoder_mask)

if cfg.dropout_rate is None:
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x + inputs
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)

# MLP block.
y = nn.LayerNorm(dtype=cfg.dtype)(x)
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
y = MlpBlock(config=cfg)(y)

return x + y
return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)


class EncoderDecoder1DBlock(nn.Module):
@@ -251,16 +268,17 @@ def __call__(self,
output after transformer encoder-decoder block.
"""
cfg = self.config
pre_ln = cfg.pre_ln

# Decoder block.
assert targets.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(targets)
x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets

if cfg.attention_dropout_rate is None:
attention_dropout_rate = 0.1
else:
attention_dropout_rate = cfg.attention_dropout_rate
x = nn.SelfAttention(
x = nn.MultiHeadDotProductAttention(
num_heads=cfg.num_heads,
dtype=cfg.dtype,
qkv_features=cfg.qkv_dim,
@@ -270,16 +288,18 @@ def __call__(self,
broadcast_dropout=False,
dropout_rate=attention_dropout_rate,
deterministic=cfg.deterministic,
decode=cfg.decode)(x, decoder_mask)
decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask)
if cfg.dropout_rate is None:
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x + targets
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)

# Encoder-Decoder block.
y = nn.LayerNorm(dtype=cfg.dtype)(x)
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
y = nn.MultiHeadDotProductAttention(
num_heads=cfg.num_heads,
dtype=cfg.dtype,
@@ -289,16 +309,20 @@ def __call__(self,
use_bias=False,
broadcast_dropout=False,
dropout_rate=attention_dropout_rate,
deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
deterministic=cfg.deterministic)(cfg.attention_temp * y,
encoded,
encoder_decoder_mask)

y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
y = y + x
if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)

# MLP block.
z = nn.LayerNorm(dtype=cfg.dtype)(y)
z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
z = MlpBlock(config=cfg)(z)

return y + z
return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)


class Encoder(nn.Module):
@@ -352,7 +376,9 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
x = Encoder1DBlock(
config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)

encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
encoded = (
nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
if cfg.pre_ln else x)

return encoded

@@ -423,7 +449,9 @@ def __call__(self,
encoded,
decoder_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask)
y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
y = (
nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
if cfg.pre_ln else y)

# Use the transpose of embedding matrix for logit transform.
logits = output_embed.attend(y.astype(jnp.float32))
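The models.py hunks above introduce four knobs: an optional GLU branch in the MLP block, a configurable activation, an attention-temperature factor multiplied into the attention query input, and a pre_ln flag that decides whether LayerNorm wraps the sublayer input (pre-LN) or the residual sum (post-LN). What follows is a minimal, self-contained Flax sketch of just the GLU and LayerNorm-placement patterns; ToyBlock and its fields are illustrative names, not code from this repository.

import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyBlock(nn.Module):
  """Illustrative residual MLP block with optional GLU and pre-/post-LN."""
  mlp_dim: int
  glu: bool = False
  pre_ln: bool = True

  @nn.compact
  def __call__(self, inputs):
    # Pre-LN normalizes the sublayer input; post-LN leaves the input as-is
    # and instead normalizes the residual sum at the end of the block.
    x = nn.LayerNorm()(inputs) if self.pre_ln else inputs
    h = nn.relu(nn.Dense(self.mlp_dim)(x))
    if self.glu:
      # Gated linear unit: elementwise gate from a second projection.
      h = h * nn.Dense(self.mlp_dim)(x)
    h = nn.Dense(inputs.shape[-1])(h)
    out = inputs + h
    return out if self.pre_ln else nn.LayerNorm()(out)


x = jnp.ones((2, 8, 16))
block = ToyBlock(mlp_dim=32, glu=True, pre_ln=False)
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)  # Same shape as the input: (2, 8, 16).

With pre_ln=True the residual path stays unnormalized and only sublayer inputs are normalized; with pre_ln=False the block reproduces the original post-LN Transformer ordering, which the post-LN workload variant further down exercises.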
72 changes: 69 additions & 3 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -1,4 +1,6 @@
"""WMT workload implemented in Jax."""

from dataclasses import replace
import functools
from typing import Any, Dict, Iterator, Optional, Tuple

@@ -215,11 +217,23 @@ def init_model_fn(
input_shape = (init_fake_batch_size, 256)
target_shape = (init_fake_batch_size, 256)

if self.activation == 'relu':
activation = nn.relu
elif self.activation == 'tanh':
activation = jnp.tanh
else:
raise ValueError(f'Unknown activation function {self.activation}.')

model_config = models.TransformerConfig(
dropout_rate=dropout_rate, attention_dropout_rate=aux_dropout_rate)
dropout_rate=dropout_rate,
attention_dropout_rate=aux_dropout_rate,
pre_ln=self.pre_ln,
attention_temp=self.attention_temp,
activation=activation,
glu=self.glu)
self._train_model = models.Transformer(model_config)
self._eval_model = models.Transformer(
models.TransformerConfig(deterministic=True))
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
initial_variables = jax.jit(self._eval_model.init)(
rng,
jnp.ones(input_shape, jnp.float32),
@@ -277,3 +291,55 @@ def _normalize_eval_metrics(
del num_examples
eval_denominator = total_metrics.pop('denominator')
return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics)


class WmtWorkloadPostLN(WmtWorkload):
"""WMT Jax workload with post instead of pre layer norm."""

@property
def validation_target_value(self) -> float:
return 30.2003

@property
def test_target_value(self) -> float:
return 29.8982

@property
def pre_ln(self) -> bool:
return False


class WmtWorkloadAttentionTemp(WmtWorkload):
"""WMT Jax workload with attention temperature = 4.0."""

@property
def validation_target_value(self) -> float:
return 30.0756

@property
def test_target_value(self) -> float:
return 29.8094

@property
def attention_temp(self) -> float:
return 4.0


class WmtWorkloadGLUTanH(WmtWorkload):
"""WMT Jax workload with GLU and TanH activations."""

@property
def validation_target_value(self) -> float:
return 30.0002

@property
def test_target_value(self) -> float:
return 29.8139

@property
def activation(self) -> str:
return 'tanh'

@property
def glu(self) -> bool:
return True
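Each variant class above changes behaviour only by overriding a property: init_model_fn reads pre_ln, attention_temp, activation, and glu when it builds the TransformerConfig, and dataclasses.replace then derives the deterministic eval config from the training config. Below is a minimal sketch of that pattern, using hypothetical ToyTransformerConfig and ToyWorkload names rather than the repository's classes.

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyTransformerConfig:
  pre_ln: bool = True
  attention_temp: float = 1.0
  activation: str = 'relu'
  glu: bool = False
  deterministic: bool = False


class ToyWorkload:
  """Base workload; the default properties mirror the standard setup."""

  @property
  def pre_ln(self) -> bool:
    return True

  def init_model_fn(self):
    train_config = ToyTransformerConfig(pre_ln=self.pre_ln)
    # Evaluation reuses the training config and flips only determinism.
    eval_config = replace(train_config, deterministic=True)
    return train_config, eval_config


class ToyWorkloadPostLN(ToyWorkload):
  """Variant: switch to post-LN by overriding a single property."""

  @property
  def pre_ln(self) -> bool:
    return False


train_cfg, eval_cfg = ToyWorkloadPostLN().init_model_fn()
assert not train_cfg.pre_ln and eval_cfg.deterministic

A real variant additionally overrides its validation and test target values, as the three subclasses above do; everything else is inherited unchanged.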