Rename norm_first to pre_ln
runame committed Dec 12, 2023
1 parent 76d0f44 commit a026ce7
Showing 5 changed files with 41 additions and 41 deletions.
30 changes: 15 additions & 15 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/models.py
@@ -36,7 +36,7 @@ class TransformerConfig:
kernel_init: Callable = nn.initializers.xavier_uniform()
bias_init: Callable = nn.initializers.normal(stddev=1e-6)
posemb_init: Optional[Callable] = None
- norm_first: bool = True
+ pre_ln: bool = True


def shift_right(x, axis=1):
@@ -204,11 +204,11 @@ def __call__(self, inputs, encoder_mask=None):
output after transformer encoder block.
"""
cfg = self.config
- norm_first = cfg.norm_first
+ pre_ln = cfg.pre_ln

# Attention block.
assert inputs.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if norm_first else inputs
+ x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
if cfg.attention_dropout_rate is None:
attention_dropout_rate = 0.1
else:
@@ -232,14 +232,14 @@ def __call__(self, inputs, encoder_mask=None):
dropout_rate = cfg.dropout_rate
x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x + inputs
- if not norm_first:
+ if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)

# MLP block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if norm_first else x
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
y = MlpBlock(config=cfg)(y)

- return x + y if norm_first else nn.LayerNorm(dtype=cfg.dtype)(x + y)
+ return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)


class EncoderDecoder1DBlock(nn.Module):
@@ -268,11 +268,11 @@ def __call__(self,
output after transformer encoder-decoder block.
"""
cfg = self.config
- norm_first = cfg.norm_first
+ pre_ln = cfg.pre_ln

# Decoder block.
assert targets.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(targets) if norm_first else targets
+ x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets

if cfg.attention_dropout_rate is None:
attention_dropout_rate = 0.1
@@ -295,11 +295,11 @@ def __call__(self,
dropout_rate = cfg.dropout_rate
x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x + targets
- if not norm_first:
+ if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)

# Encoder-Decoder block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if norm_first else x
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
y = nn.MultiHeadDotProductAttention(
num_heads=cfg.num_heads,
dtype=cfg.dtype,
@@ -315,14 +315,14 @@ def __call__(self,

y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
y = y + x
- if not norm_first:
+ if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)

# MLP block.
- z = nn.LayerNorm(dtype=cfg.dtype)(y) if norm_first else y
+ z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
z = MlpBlock(config=cfg)(z)

- return y + z if norm_first else nn.LayerNorm(dtype=cfg.dtype)(y + z)
+ return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)


class Encoder(nn.Module):
@@ -378,7 +378,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None):

encoded = (
nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
- if cfg.norm_first else x)
+ if cfg.pre_ln else x)

return encoded

@@ -451,7 +451,7 @@ def __call__(self,
encoder_decoder_mask=encoder_decoder_mask)
y = (
nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
- if cfg.norm_first else y)
+ if cfg.pre_ln else y)

# Use the transpose of embedding matrix for logit transform.
logits = output_embed.attend(y.astype(jnp.float32))
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -227,7 +227,7 @@ def init_model_fn(
model_config = models.TransformerConfig(
dropout_rate=dropout_rate,
attention_dropout_rate=aux_dropout_rate,
- norm_first=self.norm_first,
+ pre_ln=self.pre_ln,
attention_temp=self.attention_temp,
activation=activation,
glu=self.glu)
@@ -305,7 +305,7 @@ def test_target_value(self) -> float:
return 29.8982

@property
- def norm_first(self) -> bool:
+ def pre_ln(self) -> bool:
return False


42 changes: 21 additions & 21 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -110,7 +110,7 @@ def __init__(self,
glu: bool = False,
layer_norm_eps: float = 1e-6,
attention_temp: float = 1.0,
- norm_first: bool = True):
+ pre_ln: bool = True):
super().__init__()
if dropout_rate is None:
dropout_rate = 0.1
@@ -128,7 +128,7 @@ def __init__(self,
glu,
layer_norm_eps,
attention_temp,
- norm_first)
+ pre_ln)
self.decoder = Decoder(d_model,
nhead,
d_hid,
@@ -139,7 +139,7 @@ def __init__(self,
glu,
layer_norm_eps,
attention_temp,
- norm_first)
+ pre_ln)
# Share positional encoding and embedding between encoder and decoder.
self.encoder.pos_encoder = self.pos_encoder
self.encoder.shared_embedding = self.shared_embedding
@@ -267,7 +267,7 @@ def __init__(self,
glu: bool = False,
layer_norm_eps: float = 1e-6,
attention_temp: float = 1.0,
- norm_first: bool = True):
+ pre_ln: bool = True):
super().__init__()
self.nhead = nhead
self.shared_embedding = None
@@ -282,9 +282,9 @@ def __init__(self,
glu=glu,
layer_norm_eps=layer_norm_eps,
attention_temp=attention_temp,
- norm_first=norm_first)
+ pre_ln=pre_ln)
encoder_norm = (
- nn.LayerNorm(d_model, eps=layer_norm_eps) if norm_first else None)
+ nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)

def forward(self,
@@ -312,7 +312,7 @@ def __init__(self,
glu: bool = False,
layer_norm_eps: float = 1e-6,
attention_temp: float = 1.0,
- norm_first: bool = True):
+ pre_ln: bool = True):
super().__init__()
self.nhead = nhead
self.shared_embedding = None
@@ -327,7 +327,7 @@ def __init__(self,
layer_norm_eps,
nlayers,
attention_temp,
- norm_first)
+ pre_ln)

def forward(
self,
@@ -443,15 +443,15 @@ class TransformerEncoderLayer(nn.Module):
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
(default=1e-6).
- norm_first: if ``True``, layer norm is done prior to attention and
+ pre_ln: if ``True``, layer norm is done prior to attention and
feedforward operations, respectively. Otherwise it's done after.
Default: ``True``.
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
"""
- __constants__ = ['norm_first']
+ __constants__ = ['pre_ln']

def __init__(self,
d_model: int = 1024,
@@ -463,7 +463,7 @@ def __init__(self,
glu: bool = False,
layer_norm_eps: float = 1e-6,
attention_temp: float = 1.0,
- norm_first: bool = True,
+ pre_ln: bool = True,
device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -485,7 +485,7 @@ def __init__(self,
self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

- self.norm_first = norm_first
+ self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.dropout1 = nn.Dropout(dropout_rate)
@@ -504,7 +504,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
see the docs in Transformer class.
"""
x = src
- if self.norm_first:
+ if self.pre_ln:
x = x + self._sa_block(self.norm1(x), src_mask)
x = x + self._ff_block(self.norm2(x))
else:
@@ -562,7 +562,7 @@ def __init__(self,
layer_norm_eps,
num_layers,
attention_temp,
- norm_first):
+ pre_ln):
super().__init__()
self.layers = nn.ModuleList([
TransformerDecoderLayer(
@@ -575,11 +575,11 @@ def __init__(self,
glu,
layer_norm_eps=layer_norm_eps,
attention_temp=attention_temp,
- norm_first=norm_first) for _ in range(num_layers)
+ pre_ln=pre_ln) for _ in range(num_layers)
])
self.num_layers = num_layers
self.norm = (
- nn.LayerNorm(d_model, eps=layer_norm_eps) if norm_first else None)
+ nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)

def forward(self,
tgt: Tensor,
@@ -642,7 +642,7 @@ class TransformerDecoderLayer(nn.Module):
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
(default=1e-6).
- norm_first: if ``True``, layer norm is done prior to self attention,
+ pre_ln: if ``True``, layer norm is done prior to self attention,
multihead attention and feedforward operations, respectively.
Otherwise it's done after. Default: ``True``.
Examples::
@@ -651,7 +651,7 @@ class TransformerDecoderLayer(nn.Module):
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)
"""
- __constants__ = ['norm_first']
+ __constants__ = ['pre_ln']

def __init__(self,
d_model: int = 1024,
Expand All @@ -662,7 +662,7 @@ def __init__(self,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
- norm_first: bool = True,
+ pre_ln: bool = True,
attention_temp: float = 1.0,
device=None,
dtype=None) -> None:
@@ -695,7 +695,7 @@ def __init__(self,
self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

- self.norm_first = norm_first
+ self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
@@ -729,7 +729,7 @@ def forward( # pylint: disable=arguments-renamed
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

x = tgt
- if self.norm_first:
+ if self.pre_ln:
sa_out, cache = self._sa_block(
self.norm1(x),
tgt_mask,
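Note on the renamed flag: as the pre_ln docstrings above describe (and as Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf, referenced in the decoder's forward, illustrates), the flag selects between pre-layer-norm and post-layer-norm residual orderings. The snippet below is a minimal, self-contained PyTorch sketch of the two orderings for illustration only; ToyEncoderBlock, its layer sizes, and the plain nn.MultiheadAttention usage are hypothetical choices and not this repository's implementation.

import torch
import torch.nn as nn


class ToyEncoderBlock(nn.Module):
  """Illustrative encoder block: pre_ln=True normalizes each sublayer's
  input (pre-LN); pre_ln=False normalizes after the residual add (post-LN)."""

  def __init__(self, d_model: int = 16, nhead: int = 4, dim_ff: int = 32,
               pre_ln: bool = True):
    super().__init__()
    self.pre_ln = pre_ln
    self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
    self.mlp = nn.Sequential(
        nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model))
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self.pre_ln:
      # Pre-LN: LayerNorm feeds each sublayer; the residual stream itself
      # stays un-normalized.
      y = self.norm1(x)
      x = x + self.self_attn(y, y, y, need_weights=False)[0]
      x = x + self.mlp(self.norm2(x))
    else:
      # Post-LN: add the sublayer output to the residual first, then
      # normalize the sum.
      x = self.norm1(x + self.self_attn(x, x, x, need_weights=False)[0])
      x = self.norm2(x + self.mlp(x))
    return x


if __name__ == '__main__':
  block = ToyEncoderBlock(pre_ln=True)
  out = block(torch.randn(2, 5, 16))  # (batch, seq_len, d_model)
  print(out.shape)  # torch.Size([2, 5, 16])

This is also why the encoder_norm and self.norm lines in the diff above construct a final LayerNorm only when pre_ln is True: the post-LN variant already normalizes after every residual addition, while the pre-LN variant needs one last normalization on the stack output.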
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py
@@ -183,7 +183,7 @@ def init_model_fn(
model = Transformer(
dropout_rate=dropout_rate,
attention_dropout_rate=aux_dropout_rate,
- norm_first=self.norm_first,
+ pre_ln=self.pre_ln,
attention_temp=self.attention_temp,
activation=activation,
glu=self.glu)
@@ -362,7 +362,7 @@ def test_target_value(self) -> float:
return 29.8982

@property
- def norm_first(self) -> bool:
+ def pre_ln(self) -> bool:
return False


2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/workload.py
@@ -100,7 +100,7 @@ def step_hint(self) -> int:
return 133_333

@property
- def norm_first(self) -> bool:
+ def pre_ln(self) -> bool:
return True

@property
