Commit: juhan/imagenet_vit_variant
pomonam committed Dec 6, 2023
1 parent 98146be commit 269a7b6
Showing 7 changed files with 138 additions and 14 deletions.
[changed file]
@@ -8,7 +8,7 @@
import math

import tensorflow as tf
from tensorflow_addons import image as contrib_image
# from tensorflow_addons import image as contrib_image

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
[changed file]
@@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int,
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
use_glu: bool = False
dropout_rate: float = 0.0

@nn.compact
@@ -46,7 +47,10 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:

d = x.shape[2]
x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
x = nn.gelu(x)
if self.use_glu:
x = nn.glu(x)
else:
x = nn.gelu(x)
x = nn.Dropout(rate=self.dropout_rate)(x, train)
x = nn.Dense(d, **inits)(x)
return x
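
Note: flax.linen's glu gates one half of the features with the sigmoid of the other half, so with use_glu the effective hidden width is mlp_dim // 2 rather than mlp_dim. A standalone shape check (illustrative only, not part of this commit):

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 16, 1536))  # (batch, tokens, mlp_dim) activations after the first Dense layer
y = nn.glu(x)                # split the last axis in half; gate one half with sigmoid of the other
print(y.shape)               # (2, 16, 768)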
@@ -56,11 +60,16 @@ class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
use_glu: bool = False
post_layer_norm: bool = False
dropout_rate: float = 0.0

@nn.compact
def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
y = nn.LayerNorm(name='LayerNorm_0')(x)
if not self.post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
else:
y = x
y = nn.SelfAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform(),
@@ -69,13 +78,20 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
y)
y = nn.Dropout(rate=self.dropout_rate)(y, train)
x = x + y
if self.post_layer_norm:
x = nn.LayerNorm(name='LayerNorm_0')(x)

y = nn.LayerNorm(name='LayerNorm_2')(x)
if not self.post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_2')(x)
else:
y = x
y = MlpBlock(
mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate,
mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate,
name='MlpBlock_3')(y, train)
y = nn.Dropout(rate=self.dropout_rate)(y, train)
x = x + y
if self.post_layer_norm:
x = nn.LayerNorm(name='LayerNorm_2')(x)
return x


@@ -84,6 +100,8 @@ class Encoder(nn.Module):
depth: int
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
use_glu: bool = False
post_layer_norm: bool = False
dropout_rate: float = 0.0

@nn.compact
@@ -94,9 +112,43 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
name=f'encoderblock_{lyr}',
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout_rate=self.dropout_rate)
use_glu=self.use_glu,
post_layer_norm=self.post_layer_norm,
dropout_rate=self.dropout_rate,)
x = block(x, train)
return nn.LayerNorm(name='encoder_layernorm')(x)

if not self.post_layer_norm:
return nn.LayerNorm(name='encoder_layernorm')(x)
else:
return x


class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12
post_layer_norm: bool = False

@nn.compact
def __call__(self, x):
n, _, d = x.shape
probe = self.param('probe', nn.initializers.xavier_uniform(),
(1, 1, d), x.dtype)
probe = jnp.tile(probe, [n, 1, 1])

x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform())(probe, x)

if self.post_layer_norm:
y = x
x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
x = nn.LayerNorm(name='MAP')(x)
else:
y = nn.LayerNorm(name='MAP')(x)
x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)

return x[:, 0]
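
As a shape-level sanity check, MAP pooling reduces a (batch, tokens, width) sequence to one (batch, width) vector per image. A minimal sketch, assuming the MAPHead module above is importable (hyperparameters here are illustrative, not from this commit):

import jax
import jax.numpy as jnp

tokens = jnp.ones((8, 196, 384))           # (batch, patches, width), e.g. ViT-S/16 at 224x224
head = MAPHead(num_heads=6, mlp_dim=1536)  # illustrative settings
variables = head.init(jax.random.PRNGKey(0), tokens)
pooled = head.apply(variables, tokens)
print(pooled.shape)                        # (8, 384): one pooled representation per image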


class ViT(nn.Module):
@@ -109,9 +161,12 @@ class ViT(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
rep_size: Union[int, bool] = True
post_layer_norm: bool = False
dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
pool_type: str = 'gap'
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False

def get_posemb(self,
seqshape: tuple,
@@ -146,10 +201,21 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout_rate=dropout_rate,
use_glu=self.use_glu,
post_layer_norm=self.post_layer_norm,
name='Transformer')(
x, train=not train)

x = jnp.mean(x, axis=1)
if self.pool_type == 'map':
x = MAPHead(
num_heads=self.num_heads,
mlp_dim=self.mlp_dim,
post_layer_norm=self.post_layer_norm)(
x)
elif self.pool_type == 'gap':
x = jnp.mean(x, axis=1)
else:
raise ValueError(f'Unknown pool type: "{self.pool_type}"')

if self.rep_size:
rep_size = self.width if self.rep_size is True else self.rep_size
[changed file]
@@ -37,6 +37,7 @@ def init_model_fn(
self._model = models.ViT(
dropout_rate=dropout_rate,
num_classes=self._num_classes,
use_glu=self.use_glu,
**decode_variant('S/16'))
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
[changed file]
@@ -39,7 +39,9 @@ def __init__(
self,
width: int,
mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
dropout_rate: float = 0.0) -> None:
use_glu: bool = False,
dropout_rate: float = 0.0,
) -> None:
super().__init__()

self.width = width
@@ -48,7 +50,7 @@ def __init__(

self.net = nn.Sequential(
nn.Linear(self.width, self.mlp_dim),
nn.GELU(),
nn.GLU() if use_glu else nn.GELU(),
nn.Dropout(self.dropout_rate),
# nn.GLU halves the last dimension, so the final projection input must match.
nn.Linear(self.mlp_dim // 2 if use_glu else self.mlp_dim, self.width))
self.reset_parameters()
@@ -129,29 +131,42 @@ def __init__(self,
width: int,
mlp_dim: Optional[int] = None,
num_heads: int = 12,
use_glu: bool = False,
post_layer_norm: bool = False,
dropout_rate: float = 0.0) -> None:
super().__init__()

self.width = width
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.post_layer_norm = post_layer_norm

self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6)
self.self_attention1 = SelfAttention(self.width, self.num_heads)
self.dropout = nn.Dropout(dropout_rate)
self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
self.mlp3 = MlpBlock(self.width, self.mlp_dim, dropout_rate)
self.mlp3 = MlpBlock(self.width, self.mlp_dim, use_glu, dropout_rate)

def forward(self, x: spec.Tensor) -> spec.Tensor:
y = self.layer_norm0(x)
if not self.post_layer_norm:
y = self.layer_norm0(x)
else:
y = x
y = self.self_attention1(y)
y = self.dropout(y)
x = x + y
if self.post_layer_norm:
x = self.layer_norm0(x)

y = self.layer_norm2(x)
if not self.post_layer_norm:
y = self.layer_norm2(x)
else:
y = x
y = self.mlp3(y)
y = self.dropout(y)
x = x + y
if self.post_layer_norm:
x = self.layer_norm2(x)
return x


@@ -163,6 +178,8 @@ def __init__(self,
width: int,
mlp_dim: Optional[int] = None,
num_heads: int = 12,
use_glu: bool = False,
post_layer_norm: bool = False,
dropout_rate: float = 0.0) -> None:
super().__init__()

@@ -172,7 +189,7 @@ def __init__(
self.num_heads = num_heads

self.net = nn.ModuleList([
Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, dropout_rate)
Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, use_glu, post_layer_norm, dropout_rate)
for _ in range(depth)
])
self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6)
@@ -184,6 +201,19 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
return self.encoder_norm(x)


# class MAPHead(nn.Module):
#   def __init__(self, mlp_dim: Optional[int] = None, num_heads: int = 12, post_layer_norm: bool = False):
#     super().__init__()
#     self.mlp_dim = mlp_dim
#     self.num_heads = num_heads
#     self.post_layer_norm = post_layer_norm
#
#     self.probe = nn.Parameter(torch.zeros(1, 1, self.mlp_dim))
#
#   def forward(self, x: spec.Tensor) -> spec.Tensor:
#     n, _, d = x.shape

class ViT(nn.Module):
"""ViT model."""

@@ -202,6 +232,8 @@ def __init__(
rep_size: Union[int, bool] = True,
dropout_rate: Optional[float] = 0.0,
head_zeroinit: bool = True,
post_layer_norm: bool = False,
use_glu: bool = False,
dtype: Any = torch.float32) -> None:
super().__init__()
if dropout_rate is None:
@@ -234,6 +266,8 @@ def __init__(
width=self.width,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=use_glu,
post_layer_norm=post_layer_norm,
dropout_rate=dropout_rate)

if self.num_classes:
[changed file]
@@ -34,6 +34,7 @@ def init_model_fn(
model = models.ViT(
dropout_rate=dropout_rate,
num_classes=self._num_classes,
use_glu=self.use_glu,
**decode_variant('S/16'))
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
@@ -77,3 +78,10 @@ def model_fn(
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])

return logits_batch, None


class ImagenetVitPostLayerNormWorkload(ImagenetVitWorkload):
@property
def use_post_layer_norm(self) -> bool:
"""Whether to use layer normalization after the residual branch."""
return True
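
Other variants can follow the same override pattern; for example, a hypothetical GLU variant (illustrative only, not defined in this commit) might look like:

class ImagenetVitGluWorkload(ImagenetVitWorkload):

  @property
  def use_glu(self) -> bool:
    """Whether to use GLU in the MlpBlock."""
    return True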
[changed file] algorithmic_efficiency/workloads/imagenet_vit/workload.py (15 additions, 0 deletions)
@@ -60,6 +60,21 @@ def validation_target_value(self) -> float:
def test_target_value(self) -> float:
return 1 - 0.3481 # 0.6519

@property
def use_post_layer_norm(self) -> bool:
"""Whether to use layer normalization after the residual branch."""
return False

@property
def use_map(self) -> bool:
"""Whether to use multihead attention pooling."""
return False

@property
def use_glu(self) -> bool:
"""Whether to use GLU in the MLPBlock."""
return False

@property
def eval_batch_size(self) -> int:
return 2048
[changed file]
Empty file.
