From 711f6308fd02e23fae0d27514e08b06bb1fd5059 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 11:33:40 -0500 Subject: [PATCH 001/155] minor --- .../imagenet_jax/randaugment.py | 2 +- .../imagenet_vit/imagenet_jax/models.py | 103 +++++++++-- .../imagenet_vit/imagenet_jax/workload.py | 28 ++- .../imagenet_vit/imagenet_pytorch/models.py | 121 ++++++++++--- .../imagenet_vit/imagenet_pytorch/workload.py | 28 ++- .../workloads/imagenet_vit/workload.py | 15 ++ tests/modeldiffs/imagenet_vit/compare.py | 67 ++++++- tests/modeldiffs/imagenet_vit/compare_glu.py | 163 ++++++++++++++++++ .../imagenet_vit/compare_post_ln.py | 163 ++++++++++++++++++ 9 files changed, 644 insertions(+), 46 deletions(-) create mode 100644 tests/modeldiffs/imagenet_vit/compare_glu.py create mode 100644 tests/modeldiffs/imagenet_vit/compare_post_ln.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..8fa1c0789 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,7 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image +# from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index ab5d1839e..4a97ee661 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. + use_glu: bool = False dropout_rate: float = 0.0 @nn.compact @@ -47,6 +48,13 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: d = x.shape[2] x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) x = nn.gelu(x) + + if self.use_glu: + y = nn.Dense( + self.mlp_dim, + **inits)(x) + x = x * y + x = nn.Dropout(rate=self.dropout_rate)(x, train) x = nn.Dense(d, **inits)(x) return x @@ -56,26 +64,47 @@ class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False dropout_rate: float = 0.0 @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + if not self.use_post_layer_norm: + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + else: + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + x) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = MlpBlock( + mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, + name='MlpBlock_3')(x, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) + return x @@ -85,6 +114,8 @@ class Encoder(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 dropout_rate: float = 0.0 + use_glu: bool = False + use_post_layer_norm: bool = False @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: @@ -94,9 +125,35 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: name=f'encoderblock_{lyr}', mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=self.dropout_rate) x = block(x, train) - return nn.LayerNorm(name='encoder_layernorm')(x) + if not self.use_post_layer_norm: + return nn.LayerNorm(name='encoder_layernorm')(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + @nn.compact + def __call__(self, x): + n, _, d = x.shape + probe = self.param('probe', + nn.initializers.xavier_uniform(), + (1, 1, d), x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] class ViT(nn.Module): @@ -112,6 +169,9 @@ class ViT(nn.Module): dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. 
reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, def get_posemb(self, seqshape: tuple, @@ -145,11 +205,18 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate, name='Transformer')( x, train=not train) - x = jnp.mean(x, axis=1) + if self.use_map: + x = MAPHead(num_heads=self.num_heads, + mlp_dim=self.mlp_dim + )(x) + else: + x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 3f3af0564..22fcde66a 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,11 +32,16 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + head_zeroinit: bool = True) -> spec.ModelInitState: del aux_dropout_rate self._model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + head_zeroinit=head_zeroinit, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) @@ -83,3 +88,24 @@ def _eval_model_on_split(self, rng, data_dir, global_step) + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetViTPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetViTMapLNWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 55a8e370d..053b0ec76 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -39,18 +39,26 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
+ use_glu: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width + self.use_glu = use_glu self.dropout_rate = dropout_rate - self.net = nn.Sequential( - nn.Linear(self.width, self.mlp_dim), - nn.GELU(), - nn.Dropout(self.dropout_rate), - nn.Linear(self.mlp_dim, self.width)) + self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.act_fnc = nn.GELU(approximate='tanh') + self.dropout = nn.Dropout(self.dropout_rate) + + if self.use_glu: + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + else: + self.glu_linear = None + + self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.reset_parameters() def reset_parameters(self) -> None: @@ -61,7 +69,16 @@ def reset_parameters(self) -> None: module.bias.data.normal_(std=1e-6) def forward(self, x: spec.Tensor) -> spec.Tensor: - return self.net(x) + x = self.linear1(x) + x = self.act_fnc(x) + + if self.use_glu: + y = self.glu_linear(x) + x = x * y + + x = self.dropout(x) + x = self.linear2(x) + return x class SelfAttention(nn.Module): @@ -129,29 +146,44 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(self.width, self.mlp_dim, dropout_rate) + self.mlp3 = MlpBlock(width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: - y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) - x = x + y - - y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) - x = x + y + if not self.use_post_layer_norm: + y = self.layer_norm0(x) + y = self.self_attention1(y) + y = self.dropout(y) + x = x + y + + y = self.layer_norm2(x) + y = self.mlp3(y) + y = self.dropout(y) + x = x + y + else: + y = self.self_attention1(x) + y = self.dropout(y) + x = x + y + x = self.layer_norm0(x) + + y = self.mlp3(x) + y = self.dropout(y) + x = x + y + x = self.layer_norm2(x) return x @@ -163,6 +195,8 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() @@ -170,18 +204,53 @@ def __init__(self, self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, dropout_rate) + Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, self.use_glu, self.use_post_layer_norm, dropout_rate) for _ in range(depth) ]) - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + + if not self.use_post_layer_norm: + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + else: + self.encoder_norm = None def forward(self, x: spec.Tensor) -> spec.Tensor: # Input Encoder. 
for block in self.net: x = block(x) - return self.encoder_norm(x) + if not self.use_post_layer_norm: + return self.encoder_norm(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12): + super().__init__() + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) + nn.init.xavier_uniform_(self.probe.data) + + self.mha = nn.MultiheadAttention(embed_dim=self.width, num_heads=self.num_heads) + self.layer_nrom = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) + + def forward(self, x): + n, _, _ = x.shape + probe = torch.tile(self.probe, [n, 1, 1]) + + x = self.mha(probe, x) + y = self.layer_nrom(x) + x = x + self.mlp(y) + return x[:, 0] class ViT(nn.Module): @@ -202,6 +271,9 @@ def __init__( rep_size: Union[int, bool] = True, dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() if dropout_rate is None: @@ -215,6 +287,9 @@ def __init__( self.num_heads = num_heads self.rep_size = rep_size self.head_zeroinit = head_zeroinit + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.use_map = use_map self.dtype = dtype if self.rep_size: @@ -234,6 +309,8 @@ def __init__( width=self.width, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate) if self.num_classes: @@ -270,7 +347,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.dropout(x) x = self.encoder(x) - x = torch.mean(x, dim=1) + + if self.use_map: + pass + else: + x = torch.mean(x, dim=1) if self.rep_size: x = torch.tanh(self.pre_logits(x)) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 08a62ede6..9e8af3a68 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -28,12 +28,17 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + head_zeroinit: bool = True) -> spec.ModelInitState: del aux_dropout_rate torch.random.manual_seed(rng[0]) model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + head_zeroinit=head_zeroinit, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -77,3 +82,24 @@ def model_fn( logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetViTPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetViTMapLNWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py 
b/algorithmic_efficiency/workloads/imagenet_vit/workload.py index 61d3acfd3..ed0118ca0 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/workload.py @@ -60,6 +60,21 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: return 1 - 0.3481 # 0.6519 + @property + def use_post_layer_norm(self) -> bool: + """Whether to use layer normalization after the residual branch.""" + return False + + @property + def use_map(self) -> bool: + """Whether to use multihead attention pooling.""" + return False + + @property + def use_glu(self) -> bool: + """Whether to use GLU in the MLPBlock.""" + return False + @property def eval_batch_size(self) -> int: return 2048 diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 1022b5b54..3e8b9dcb1 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -3,20 +3,75 @@ # Disable GPU access for both jax and pytorch. os.environ['CUDA_VISIBLE_DEVICES'] = '' -import jax -import torch - from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload -from tests.modeldiffs.diff import out_diff +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) def key_transform(k): if 'Conv' in k[0]: - k = ('embedding', *k[1:]) + k = ('conv_patch_extract', *k[1:]) elif k[0] == 'Linear_0': k = ('pre_logits', *k[1:]) elif k[0] == 'Linear_1': @@ -35,6 +90,8 @@ 
def key_transform(k): continue if 'CustomBatchNorm' in i: continue + if 'GLU' in i: + pass if 'Linear' in i: if attention: i = { diff --git a/tests/modeldiffs/imagenet_vit/compare_glu.py b/tests/modeldiffs/imagenet_vit/compare_glu.py new file mode 100644 index 000000000..a6f01f971 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/compare_glu.py @@ -0,0 +1,163 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + + +def key_transform(k): + if 'Conv' in k[0]: + k = ('conv_patch_extract', *k[1:]) + elif k[0] == 'Linear_0': + k = ('pre_logits', *k[1:]) + elif k[0] == 'Linear_1': + k = ('head', *k[1:]) + + new_key = [] + bn = False + attention = False + ln = False + enc_block = False + for idx, i in enumerate(k): + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + attention = attention or 'SelfAttention' in i + if 'ModuleList' in i or 'Sequential' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'GLU' in i: + pass + if 'Linear' in i: + if attention: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] + else: + i = i.replace('Linear', 'Dense') + elif 'Conv2d' in i: + i = i.replace('Conv2d', 'Conv') + elif 'Encoder1DBlock' in i: + i = i.replace('Encoder1DBlock', 'encoderblock') + enc_block = True + elif 
'Encoder' in i: + i = 'Transformer' + elif enc_block and 'SelfAttention' in i: + i = 'MultiHeadDotProductAttention_1' + elif enc_block and i == 'LayerNorm_1': + i = 'LayerNorm_2' + elif enc_block and 'MlpBlock' in i: + i = 'MlpBlock_3' + elif idx == 1 and i == 'LayerNorm_0': + i = 'encoder_layernorm' + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/compare_post_ln.py b/tests/modeldiffs/imagenet_vit/compare_post_ln.py new file mode 100644 index 000000000..e27d77482 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/compare_post_ln.py @@ -0,0 +1,163 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + 
out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + + +def key_transform(k): + if 'Conv' in k[0]: + k = ('conv_patch_extract', *k[1:]) + elif k[0] == 'Linear_0': + k = ('pre_logits', *k[1:]) + elif k[0] == 'Linear_1': + k = ('head', *k[1:]) + + new_key = [] + bn = False + attention = False + ln = False + enc_block = False + for idx, i in enumerate(k): + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + attention = attention or 'SelfAttention' in i + if 'ModuleList' in i or 'Sequential' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'GLU' in i: + pass + if 'Linear' in i: + if attention: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] + else: + i = i.replace('Linear', 'Dense') + elif 'Conv2d' in i: + i = i.replace('Conv2d', 'Conv') + elif 'Encoder1DBlock' in i: + i = i.replace('Encoder1DBlock', 'encoderblock') + enc_block = True + elif 'Encoder' in i: + i = 'Transformer' + elif enc_block and 'SelfAttention' in i: + i = 'MultiHeadDotProductAttention_1' + elif enc_block and i == 'LayerNorm_1': + i = 'LayerNorm_2' + elif enc_block and 'MlpBlock' in i: + i = 'MlpBlock_3' + elif idx == 1 and i == 'LayerNorm_0': + i = 'encoder_layernorm' + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. 
+ image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) From c8a9e728486ea6faffd11ac8df4e90876377b0a7 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:20:40 -0500 Subject: [PATCH 002/155] Clean up model diff --- .../imagenet_jax/randaugment.py | 1 + .../imagenet_vit/imagenet_jax/models.py | 6 +- .../imagenet_vit/imagenet_jax/workload.py | 4 +- .../imagenet_vit/imagenet_pytorch/models.py | 51 ++++-- .../imagenet_vit/imagenet_pytorch/workload.py | 6 +- tests/modeldiffs/imagenet_vit/compare.py | 66 +------ tests/modeldiffs/imagenet_vit/compare_glu.py | 163 ------------------ .../imagenet_vit/compare_post_ln.py | 163 ------------------ tests/modeldiffs/imagenet_vit/glu_compare.py | 52 ++++++ .../imagenet_vit/post_ln_compare.py | 52 ++++++ 10 files changed, 154 insertions(+), 410 deletions(-) delete mode 100644 tests/modeldiffs/imagenet_vit/compare_glu.py delete mode 100644 tests/modeldiffs/imagenet_vit/compare_post_ln.py create mode 100644 tests/modeldiffs/imagenet_vit/glu_compare.py create mode 100644 tests/modeldiffs/imagenet_vit/post_ln_compare.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 8fa1c0789..caa77ae35 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + # from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 4a97ee661..c88132621 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -88,19 +88,21 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y else: + y = x y = nn.SelfAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1')( - x) + y) y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) + y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(x, train) + name='MlpBlock_3')(y, train) y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 22fcde66a..1acd58bcd 100644 --- 
a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,8 +32,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - head_zeroinit: bool = True) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate self._model = models.ViT( dropout_rate=dropout_rate, @@ -41,7 +40,6 @@ def init_model_fn( use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - head_zeroinit=head_zeroinit, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 053b0ec76..469716d59 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -1,8 +1,8 @@ """PyTorch implementation of refactored and simplified ViT. Adapted from: -https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit. -https://github.com/lucidrains/vit-pytorch. +https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit +and https://github.com/lucidrains/vit-pytorch. """ import math @@ -14,9 +14,12 @@ from algorithmic_efficiency import init_utils from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ + MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: + """Follows the MoCo v3 logic.""" _, width, h, w = patches.shape device = patches.device y, x = torch.meshgrid(torch.arange(h, device=device), @@ -161,7 +164,11 @@ def __init__(self, self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=dropout_rate) + self.mlp3 = MlpBlock( + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: if not self.use_post_layer_norm: @@ -175,12 +182,14 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.dropout(y) x = x + y else: - y = self.self_attention1(x) + y = x + y = self.self_attention1(y) y = self.dropout(y) x = x + y x = self.layer_norm0(x) - y = self.mlp3(x) + y = x + y = self.mlp3(y) y = self.dropout(y) x = x + y x = self.layer_norm2(x) @@ -208,8 +217,12 @@ def __init__(self, self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, self.use_glu, self.use_post_layer_norm, dropout_rate) - for _ in range(depth) + Encoder1DBlock(self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + dropout_rate) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -230,7 +243,10 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class MAPHead(nn.Module): """Multihead Attention Pooling.""" - def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12): + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12): super().__init__() self.width = 
width self.mlp_dim = mlp_dim @@ -239,16 +255,17 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 1 self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) - self.mha = nn.MultiheadAttention(embed_dim=self.width, num_heads=self.num_heads) - self.layer_nrom = nn.LayerNorm(self.width, eps=1e-6) + self.mha = MultiheadAttention( + self.width, num_heads=self.num_heads, self_attn=False, bias=False) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x): + def forward(self, x: spec.Tensor) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x) - y = self.layer_nrom(x) + x = self.mha(probe, x)[0] + y = self.layer_norm(x) x = x + self.mlp(y) return x[:, 0] @@ -315,6 +332,12 @@ def __init__( if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) + + if self.use_map: + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + else: + self.map = None + self.reset_parameters() def reset_parameters(self) -> None: @@ -349,7 +372,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.encoder(x) if self.use_map: - pass + x = self.map(x) else: x = torch.mean(x, dim=1) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 9e8af3a68..013bc643f 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -28,8 +28,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - head_zeroinit: bool = True) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate torch.random.manual_seed(rng[0]) model = models.ViT( @@ -38,7 +37,6 @@ def init_model_fn( use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - head_zeroinit=head_zeroinit, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -98,7 +96,7 @@ def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapLNWorkload(ImagenetVitWorkload): +class ImagenetViTMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 3e8b9dcb1..39f2651a0 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -1,72 +1,18 @@ import os +from tests.modeldiffs.diff import out_diff + # Disable GPU access for both jax and pytorch. 
os.environ['CUDA_VISIBLE_DEVICES'] = '' +import jax +import torch + from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) def key_transform(k): @@ -90,8 +36,6 @@ def key_transform(k): continue if 'CustomBatchNorm' in i: continue - if 'GLU' in i: - pass if 'Linear' in i: if attention: i = { diff --git a/tests/modeldiffs/imagenet_vit/compare_glu.py b/tests/modeldiffs/imagenet_vit/compare_glu.py deleted file mode 100644 index a6f01f971..000000000 --- a/tests/modeldiffs/imagenet_vit/compare_glu.py +++ /dev/null @@ -1,163 +0,0 @@ -import os - -# Disable GPU access for both jax and pytorch. 
-os.environ['CUDA_VISIBLE_DEVICES'] = '' - -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitGluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) - - -def key_transform(k): - if 'Conv' in k[0]: - k = ('conv_patch_extract', *k[1:]) - elif k[0] == 'Linear_0': - k = ('pre_logits', *k[1:]) - elif k[0] == 'Linear_1': - k = ('head', *k[1:]) - - new_key = [] - bn = False - attention = False - ln = False - enc_block = False - for idx, i in enumerate(k): - bn = bn or 'BatchNorm' in i - ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i - if 'ModuleList' in i or 'Sequential' in i: - continue - if 'CustomBatchNorm' in i: - continue - if 'GLU' in i: - pass - if 'Linear' in i: - if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] - else: - i = i.replace('Linear', 'Dense') - elif 'Conv2d' in i: - i = i.replace('Conv2d', 'Conv') - elif 'Encoder1DBlock' in i: - i = i.replace('Encoder1DBlock', 'encoderblock') - enc_block = True - elif 'Encoder' in i: - i = 'Transformer' - elif enc_block and 'SelfAttention' in i: - i = 'MultiHeadDotProductAttention_1' - elif enc_block and i == 'LayerNorm_1': - i = 'LayerNorm_2' - elif enc_block and 'MlpBlock' in i: - i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': - i = 'encoder_layernorm' - elif 'weight' in i: - if bn or ln: - i = i.replace('weight', 'scale') - else: - i = i.replace('weight', 'kernel') - 
new_key.append(i) - return tuple(new_key) - - -sd_transform = None - -if __name__ == '__main__': - # pylint: disable=locally-disabled, not-callable - - jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() - - # Test outputs for identical weights and inputs. - image = torch.randn(2, 3, 224, 224) - - jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} - - pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) - - jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) - - out_diff( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None, - ) diff --git a/tests/modeldiffs/imagenet_vit/compare_post_ln.py b/tests/modeldiffs/imagenet_vit/compare_post_ln.py deleted file mode 100644 index e27d77482..000000000 --- a/tests/modeldiffs/imagenet_vit/compare_post_ln.py +++ /dev/null @@ -1,163 +0,0 @@ -import os - -# Disable GPU access for both jax and pytorch. -os.environ['CUDA_VISIBLE_DEVICES'] = '' - -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetViTPostLNWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) - - -def 
key_transform(k): - if 'Conv' in k[0]: - k = ('conv_patch_extract', *k[1:]) - elif k[0] == 'Linear_0': - k = ('pre_logits', *k[1:]) - elif k[0] == 'Linear_1': - k = ('head', *k[1:]) - - new_key = [] - bn = False - attention = False - ln = False - enc_block = False - for idx, i in enumerate(k): - bn = bn or 'BatchNorm' in i - ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i - if 'ModuleList' in i or 'Sequential' in i: - continue - if 'CustomBatchNorm' in i: - continue - if 'GLU' in i: - pass - if 'Linear' in i: - if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] - else: - i = i.replace('Linear', 'Dense') - elif 'Conv2d' in i: - i = i.replace('Conv2d', 'Conv') - elif 'Encoder1DBlock' in i: - i = i.replace('Encoder1DBlock', 'encoderblock') - enc_block = True - elif 'Encoder' in i: - i = 'Transformer' - elif enc_block and 'SelfAttention' in i: - i = 'MultiHeadDotProductAttention_1' - elif enc_block and i == 'LayerNorm_1': - i = 'LayerNorm_2' - elif enc_block and 'MlpBlock' in i: - i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': - i = 'encoder_layernorm' - elif 'weight' in i: - if bn or ln: - i = i.replace('weight', 'scale') - else: - i = i.replace('weight', 'kernel') - new_key.append(i) - return tuple(new_key) - - -sd_transform = None - -if __name__ == '__main__': - # pylint: disable=locally-disabled, not-callable - - jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() - - # Test outputs for identical weights and inputs. - image = torch.randn(2, 3, 224, 224) - - jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} - - pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) - - jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) - - out_diff( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None, - ) diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit/glu_compare.py new file mode 100644 index 000000000..444f1230a --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/glu_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. 
+ image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit/post_ln_compare.py new file mode 100644 index 000000000..8bf0bef7e --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/post_ln_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) From ecf8220edf11ecde32511f4dbe97888307b2cf86 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:33:38 -0500 Subject: [PATCH 003/155] Add docker image --- .../imagenet_resnet/imagenet_jax/randaugment.py | 3 +-- algorithmic_efficiency/workloads/workloads.py | 12 ++++++++++++ docker/scripts/startup.sh | 3 ++- tests/modeldiffs/imagenet_vit/compare.py | 3 +-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index caa77ae35..5f92b1482 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,8 +8,7 @@ import math import tensorflow as tf - -# from tensorflow_addons import image as contrib_image +from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. 
diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6cc53b7dd..bf444ea36 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -56,6 +56,18 @@ 'workload_path': 'imagenet_vit/imagenet', 'workload_class_name': 'ImagenetVitWorkload', }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTMapLNWorkload', + }, 'librispeech_conformer': { 'workload_path': 'librispeech_conformer/librispeech', 'workload_class_name': 'LibriSpeechConformerWorkload', diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 3f7458e4b..3b366b71c 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -113,7 +113,8 @@ done VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt" "mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ - "imagenet_resnet_large_bn_init" "imagenet_vit" "fastmri" "ogbg" \ + "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ + "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \ "criteo1tb_resnet" "criteo1tb_layernorm" "criteo_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 39f2651a0..bf7d6dfa5 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -1,7 +1,5 @@ import os -from tests.modeldiffs.diff import out_diff - # Disable GPU access for both jax and pytorch. 
os.environ['CUDA_VISIBLE_DEVICES'] = '' @@ -13,6 +11,7 @@ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff def key_transform(k): From 290807795fc8a1cf392dd7e94823569d5b651e40 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:40:36 -0500 Subject: [PATCH 004/155] Lint fix --- .../imagenet_vit/imagenet_jax/models.py | 91 ++++++++++--------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index c88132621..32e748ec7 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -50,9 +50,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = nn.gelu(x) if self.use_glu: - y = nn.Dense( - self.mlp_dim, - **inits)(x) + y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y x = nn.Dropout(rate=self.dropout_rate)(x, train) @@ -71,41 +69,45 @@ class Encoder1DBlock(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y else: - y = x - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) - - y = x - y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) + y = x + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = x + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) return x @@ -141,12 +143,13 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim 
  num_heads: int = 12
+
   @nn.compact
   def __call__(self, x):
     n, _, d = x.shape
     probe = self.param('probe',
-                       nn.initializers.xavier_uniform(),
-                       (1, 1, d), x.dtype)
+                       nn.initializers.xavier_uniform(), (1, 1, d),
+                       x.dtype)
     probe = jnp.tile(probe, [n, 1, 1])
 
     x = nn.MultiHeadDotProductAttention(
         num_heads=self.num_heads,
         kernel_init=nn.initializers.xavier_uniform())(probe, x)
 
@@ -171,9 +174,9 @@ class ViT(nn.Module):
   dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
-  use_glu: bool = False,
-  use_post_layer_norm: bool = False,
-  use_map: bool = False,
+  use_glu: bool = False
+  use_post_layer_norm: bool = False
+  use_map: bool = False
 
   def get_posemb(self,
                  seqshape: tuple,
@@ -214,9 +217,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
         x, train=not train)
 
     if self.use_map:
-      x = MAPHead(num_heads=self.num_heads,
-                  mlp_dim=self.mlp_dim
-                  )(x)
+      x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
     else:
       x = jnp.mean(x, axis=1)
 

From efa1120c49eaf7333cafcf2ee6a21b27d9044789 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 13 Dec 2023 12:50:03 +0100
Subject: [PATCH 005/155] Update version to match tag

---
 algorithmic_efficiency/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/__init__.py b/algorithmic_efficiency/__init__.py
index af0a6b8fc..a0e473e1d 100644
--- a/algorithmic_efficiency/__init__.py
+++ b/algorithmic_efficiency/__init__.py
@@ -1,3 +1,3 @@
 """Algorithmic Efficiency."""
 
-__version__ = '0.0.1'
+__version__ = '0.1.0'

From 386fabb704ad806e71ca2d40b3527fbf4880b291 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 13 Dec 2023 12:50:37 +0100
Subject: [PATCH 006/155] Restructure and fix ogbg data_dir

---
 datasets/README.md | 155 +++++++++++++++++++++++++++------------------
 1 file changed, 93 insertions(+), 62 deletions(-)

diff --git a/datasets/README.md b/datasets/README.md
index 614344978..c3feb5fed 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -1,74 +1,111 @@
-# Dataset Setup
-TL;DR:
-Use `dataset_setup.py` to download datasets.
-Usage:
+# MLCommons™ AlgoPerf: Dataset Setup
+
+## Table of Contents
+
+- [General Setup](#general-setup)
+  - [Set Data Directory (Docker Container)](#set-data-directory-docker-container)
+  - [Set Data Directory (on Host)](#set-data-directory-on-host)
+  - [Start tmux session (Recommended)](#start-tmux-session-recommended)
+  - [Clean up](#clean-up)
+- [Individual Dataset Instructions](#individual-dataset-instructions)
+  - [OGBG](#ogbg)
+  - [WMT](#wmt)
+  - [FastMRI](#fastmri)
+  - [ImageNet](#imagenet)
+  - [Criteo1TB](#criteo1tb)
+  - [LibriSpeech](#librispeech)
+    - [Training SPM Tokenizer](#training-spm-tokenizer)
+    - [Preprocessing Script](#preprocessing-script)
+
+## General Setup
+
+This document provides instructions on downloading and preparing all datasets utilized in the AlgoPerf benchmark. You can prepare the individual datasets one-by-one as needed. If your setup, such as your cloud or cluster environment, already contains these datasets, you may skip the dataset setup for this particular data (and directly specify the dataset location in the `submission_runner.py`). Just verify that you are using the same dataset version (and possibly the same preprocessing).
+
+*TL;DR to download and prepare a dataset, run `dataset_setup.py`:*
+
 ```bash
 python3 datasets/dataset_setup.py \
 --data_dir=~/data \
---
+--
 ```
 
-The complete benchmark uses 6 datasets:
-- OGBG
-- WMT
-- FastMRI
-- Imagenet
-- Criteo 1TB
-- Librispeech
+The complete benchmark uses 6 different datasets:
+
+- [OGBG](#ogbg)
+- [WMT](#wmt)
+- [FastMRI](#fastmri)
+- [Imagenet](#imagenet)
+- [Criteo 1TB](#criteo1tb)
+- [Librispeech](#librispeech)
 
-Some dataset setups will require you to sign a third party agreement with the dataset owners in order to get the donwload URLs.
+Some dataset setups will require you to sign a third-party agreement with the dataset owners in order to get the download URLs.
 
-# Per dataset instructions
-## Environment
+### Set Data Directory (Docker Container)
 
-### Set data directory (Docker container)
-If you are running the `dataset_setup.py` script from a Docker container, please
+If you are running the `dataset_setup.py` script from a Docker container, please
 make sure the data directory is mounted to a directory on your host with
--v flag. If you are following instructions from the README you will have used
+`-v` flag. If you are following instructions from the [Getting Started guide](/GETTING_STARTED.md) you will have used
 the `-v $HOME/data:/data` flag in the `docker run` command. This will mount
-the `$HOME/data` directory to the `/data` directory in the container.
-In this case set --data_dir to `/data`.
+the `$HOME/data` directory to the `/data` directory in the container.
+In this case, set `--data_dir` to `/data`.
+
 ```bash
 DATA_DIR='/data'
 ```
-### Set data directory (on host)
-Alternatively, if you are running the data download script directly on your host, feel free
-to choose whatever directory you find suitable, further submission instructions
-assume the data is stored in `~/data`.
+
+### Set Data Directory (on Host)
+
+Alternatively, if you are running the data download script directly on your host, feel free to choose whatever directory you find suitable; further submission instructions assume the data is stored in `~/data`.
+
 ```bash
 DATA_DIR='~/data'
 ```
+
 #### Start tmux session (Recommended)
-If running the dataset_setup.py on directly on host it is recommended to run
-the dataset_setup.py script in a tmux session because some of the data downloads may
-take several hours. To avoid your setup being interrupted start a tmux session:
+
+If running `dataset_setup.py` directly on the host, it is recommended to run
+the `dataset_setup.py` script in a `tmux` session because some of the data downloads may take several hours. To avoid your setup being interrupted, start a `tmux` session:
+
 ```bash
 tmux new -s data_setup
 ```
 
+### Clean up
+
+In order to avoid potential accidental deletion, this script does NOT
+delete any intermediate temporary files (such as zip archives) without a user
+confirmation. Deleting temp files is particularly important for Criteo 1TB, as
+there can be multiple copies of the dataset on disk during preprocessing if
+files are not cleaned up.
+
+By default, a user will be prompted before any files are deleted. If you do not want any temp files to be deleted, you can pass `--interactive_deletion=false` and then all files will be downloaded to the provided `--temp_dir`, and the user can manually delete these after downloading has finished.
-
-## Datasets
+## Individual Dataset Instructions
+
+### OGBG
 
-### OGBG
 From `algorithmic-efficiency` run:
+
 ```bash
 python3 datasets/dataset_setup.py \
---data_dir $DATA_DIR/ogbg \
+--data_dir $DATA_DIR \
 --ogbg
 ```
 
-### WMT
+### WMT
+
 From `algorithmic-efficiency` run:
+
 ```bash
 python3 datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
 --wmt
 ```
 
+### FastMRI
 
-## FastMRI
-Fill out form on https://fastmri.med.nyu.edu/. After filling out the form
+Fill out the form on <https://fastmri.med.nyu.edu/>. After filling out the form
 you should get an email containing the URLs for "knee_singlecoil_train",
 "knee_singlecoil_val" and "knee_singlecoil_test".
 
@@ -81,18 +118,14 @@ python3 datasets/dataset_setup.py \
 --fastmri_knee_singlecoil_test_url ''
 ```
 
-## ImageNet
-Register on https://image-net.org/ and follow directions to obtain the
-URLS for the ILSVRC2012 train and validation images.
+### ImageNet
+
-Imagenet dataset processsing is resource intensive. To avoid potential
-ResourcExhausted errors increase the maximum number of open file descriptors:
-```bash
-ulimit -n 8192
-```
+Register on <https://image-net.org/> and follow directions to obtain the
+URLs for the ILSVRC2012 train and validation images.
+The script will additionally automatically download the `matched-frequency` version of [ImageNet v2](https://www.tensorflow.org/datasets/catalog/imagenet_v2#imagenet_v2matched-frequency_default_config), which is used as the test set of the ImageNet workloads.
 
-The imagenet data pipeline differs between the pytorch and jax workloads.
-Therefore, you will have to specify the framework (pytorch or jax) through theframework flag.
+The ImageNet data pipeline differs between the PyTorch and JAX workloads.
+Therefore, you will have to specify the framework (either `pytorch` or `jax`) through the framework flag.
 
 ```bash
 python3 datasets/dataset_setup.py \
@@ -102,15 +135,22 @@ python3 datasets/dataset_setup.py \
 --imagenet_train_url \
 --imagenet_val_url \
 --framework jax
+```
 
+ImageNet dataset processing is resource intensive. To avoid potential
+ResourceExhausted errors, increase the maximum number of open file descriptors:
+
+```bash
+ulimit -n 8192
 ```
 
-Note that some functions use subprocess.Popen(..., shell=True), which can be
-dangerous if the user injects code into the --data_dir or --temp_dir flags. We
-do some basic sanitization in main(), but submitters should not let untrusted
+Note that some functions use `subprocess.Popen(..., shell=True)`, which can be
+dangerous if the user injects code into the `--data_dir` or `--temp_dir` flags. We
+do some basic sanitization in `main()`, but submitters should not let untrusted
 users run this script on their systems.
 
-## Criteo1tb
+### Criteo1TB
+
 ```bash
 python3 datasets/dataset_setup.py \
 --data_dir $DATA_DIR \
@@ -118,19 +158,10 @@ python3 datasets/dataset_setup.py \
 --criteo1tb
 ```
 
-### Clean up
-In order to avoid potential accidental deletion, this script does NOT
-delete any intermediate temporary files (such as zip archives) without a user
-confirmation. Deleting temp files is particularly important for Criteo 1TB, as
-there can be multiple copies of the dataset on disk during preprocessing if
-files are not cleaned up. If you do not want any temp files to be deleted, you
-can pass --interactive_deletion=false and then all files will be downloaded to
-the provided --temp_dir, and the user can manually delete these after
-downloading has finished.
+### LibriSpeech - -## Librispeech To download, train a tokenizer and preprocess the librispeech dataset: + ```bash python3 datasets/dataset_setup.py \ --data_dir $DATA_DIR \ @@ -138,26 +169,26 @@ python3 datasets/dataset_setup.py \ --librispeech ``` -### Notes on librispeech preprocessing #### Training SPM Tokenizer + A simple sentence piece tokenizer is trained over librispeech training data. This tokenizer is then used in later preprocessing step to tokenize transcripts. This command generates `spm_model.vocab` file in `$DATA_DIR/librispeech`: + ```bash python3 librispeech_tokenizer.py --train --data_dir=$DATA_DIR/librispeech ``` The trained tokenizer can be loaded back to do sanity check by tokenizing + de-tokenizing a constant string: + ```bash librispeech_tokenizer.py --data_dir=$DATA_DIR/librispeech ``` #### Preprocessing Script + The preprocessing script will generate `.npy` files for audio data, `features.csv` which has paths to saved audio `.npy`, and `trans.csv` which has paths to `features.csv` and transcription data. ```bash python3 librispeech_preprocess.py --data_dir=$DATA_DIR/librispeech --tokenizer_vocab_path=$DATA_DIR/librispeech/spm_model.vocab ``` - - - From 23320d2d1c5c140e93c1dfa59d28a073a80a3f4c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 13 Dec 2023 12:58:13 +0100 Subject: [PATCH 007/155] Standardize how subfolders for datasets are implemented --- datasets/dataset_setup.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f9ee2f138..f765e4a1a 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -334,6 +334,7 @@ def download_criteo1tb(data_dir, def download_cifar(data_dir, framework): + data_dir = os.path.join(data_dir, 'cifar10') if framework == 'jax': tfds.builder('cifar10:3.0.2', data_dir=data_dir).download_and_prepare() elif framework == 'pytorch': @@ -398,18 +399,18 @@ def extract(source, dest, mode='r:xz'): def setup_fastmri(data_dir, src_data_dir): + data_dir = os.path.join(data_dir, 'fastmri') train_tar_file_path = os.path.join(src_data_dir, FASTMRI_TRAIN_TAR_FILENAME) val_tar_file_path = os.path.join(src_data_dir, FASTMRI_VAL_TAR_FILENAME) test_tar_file_path = os.path.join(src_data_dir, FASTMRI_TEST_TAR_FILENAME) # Make train, val and test subdirectories - fastmri_data_dir = os.path.join(data_dir, 'fastmri') - train_data_dir = os.path.join(fastmri_data_dir, 'train') + train_data_dir = os.path.join(data_dir, 'train') os.makedirs(train_data_dir, exist_ok=True) - val_data_dir = os.path.join(fastmri_data_dir, 'val') + val_data_dir = os.path.join(data_dir, 'val') os.makedirs(val_data_dir, exist_ok=True) - test_data_dir = os.path.join(fastmri_data_dir, 'test') + test_data_dir = os.path.join(data_dir, 'test') os.makedirs(test_data_dir, exist_ok=True) # Unzip tar file into subdirectories @@ -425,6 +426,7 @@ def setup_fastmri(data_dir, src_data_dir): def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): """Downloads and returns the download dir.""" + data_dir = os.path.join(data_dir, 'imagenet') imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME) imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME) @@ -456,6 +458,7 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url): def setup_imagenet(data_dir, framework=None): + data_dir = os.path.join(data_dir, 'imagenet') if framework == 'jax': setup_imagenet_jax(data_dir) @@ -629,6 +632,7 @@ def 
download_librispeech(dataset_dir, tmp_dir): def download_mnist(data_dir): + data_dir = os.path.join(data_dir, 'MNIST') # Capitalization to match PyTorch tfds.builder('mnist', data_dir=data_dir).download_and_prepare() @@ -714,9 +718,8 @@ def main(_): raise ValueError( 'Please specify either jax or pytorch framework through framework ' 'flag.') - imagenet_data_dir = os.path.join(data_dir, 'imagenet') - download_imagenet(imagenet_data_dir, imagenet_train_url, imagenet_val_url) - setup_imagenet(imagenet_data_dir, framework=FLAGS.framework) + download_imagenet(data_dir, imagenet_train_url, imagenet_val_url) + setup_imagenet(data_dir, framework=FLAGS.framework) if FLAGS.all or FLAGS.librispeech: logging.info('Downloading Librispeech...') From c46c548976c061aa032b742d320303c3dc24c235 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 13 Dec 2023 13:53:50 +0100 Subject: [PATCH 008/155] Add resulting directory structures and file numbers/sizes --- datasets/README.md | 196 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) diff --git a/datasets/README.md b/datasets/README.md index c3feb5fed..4f7b6b880 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -93,6 +93,32 @@ python3 datasets/dataset_setup.py \ --ogbg ``` +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── ogbg +│ └── ogbg_molpcba +│ └── 0.1.3 +│ ├── dataset_info.json +│ ├── features.json +│ ├── metadata.json +│ ├── ogbg_molpcba-test.tfrecord-00000-of-00001 +│ ├── ogbg_molpcba-train.tfrecord-00000-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00001-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00002-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00003-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00004-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00005-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00006-of-00008 +│ ├── ogbg_molpcba-train.tfrecord-00007-of-00008 +│ └── ogbg_molpcba-validation.tfrecord-00000-of-00001 +``` + +In total, it should contain 13 files (via `find -type f | wc -l`) for a total of 777 MB (via `du -sch ogbg/`). +
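As a quick cross-check of the numbers above, a short Python sketch (illustrative only, not part of `dataset_setup.py`; it assumes the `~/data` layout used throughout this guide) can reproduce the `find -type f | wc -l` and `du`-style counts:

```python
import os


def summarize(root):
  """Count files and total bytes under `root`, like `find` and `du`."""
  num_files, num_bytes = 0, 0
  for dirpath, _, filenames in os.walk(root):
    for name in filenames:
      num_files += 1
      num_bytes += os.path.getsize(os.path.join(dirpath, name))
  return num_files, num_bytes


files, size = summarize(os.path.expanduser('~/data/ogbg'))
print(f'{files} files, {size / 1e6:.0f} MB')  # Expect 13 files, roughly 777 MB.
```

The same helper can be pointed at any of the dataset directories described below.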
+ ### WMT From `algorithmic-efficiency` run: @@ -103,6 +129,64 @@ python3 datasets/dataset_setup.py \ --wmt ``` +
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── wmt + ├── wmt14_translate + │ └── de-en + │ └── 1.0.0 + │ ├── dataset_info.json + │ ├── features.json + │ ├── wmt14_translate-test.tfrecord-00000-of-00001 + │ ├── wmt14_translate-train.tfrecord-00000-of-00016 + │ ├── wmt14_translate-train.tfrecord-00001-of-00016 + │ ├── wmt14_translate-train.tfrecord-00002-of-00016 + │ ├── wmt14_translate-train.tfrecord-00003-of-00016 + │ ├── wmt14_translate-train.tfrecord-00004-of-00016 + │ ├── wmt14_translate-train.tfrecord-00005-of-00016 + │ ├── wmt14_translate-train.tfrecord-00006-of-00016 + │ ├── wmt14_translate-train.tfrecord-00007-of-00016 + │ ├── wmt14_translate-train.tfrecord-00008-of-00016 + │ ├── wmt14_translate-train.tfrecord-00009-of-00016 + │ ├── wmt14_translate-train.tfrecord-00010-of-00016 + │ ├── wmt14_translate-train.tfrecord-00011-of-00016 + │ ├── wmt14_translate-train.tfrecord-00012-of-00016 + │ ├── wmt14_translate-train.tfrecord-00013-of-00016 + │ ├── wmt14_translate-train.tfrecord-00014-of-00016 + │ ├── wmt14_translate-train.tfrecord-00015-of-00016 + │ └── wmt14_translate-validation.tfrecord-00000-of-00001 + ├── wmt17_translate + │ └── de-en + │ └── 1.0.0 + │ ├── dataset_info.json + │ ├── features.json + │ ├── wmt17_translate-test.tfrecord-00000-of-00001 + │ ├── wmt17_translate-train.tfrecord-00000-of-00016 + │ ├── wmt17_translate-train.tfrecord-00001-of-00016 + │ ├── wmt17_translate-train.tfrecord-00002-of-00016 + │ ├── wmt17_translate-train.tfrecord-00003-of-00016 + │ ├── wmt17_translate-train.tfrecord-00004-of-00016 + │ ├── wmt17_translate-train.tfrecord-00005-of-00016 + │ ├── wmt17_translate-train.tfrecord-00006-of-00016 + │ ├── wmt17_translate-train.tfrecord-00007-of-00016 + │ ├── wmt17_translate-train.tfrecord-00008-of-00016 + │ ├── wmt17_translate-train.tfrecord-00009-of-00016 + │ ├── wmt17_translate-train.tfrecord-00010-of-00016 + │ ├── wmt17_translate-train.tfrecord-00011-of-00016 + │ ├── wmt17_translate-train.tfrecord-00012-of-00016 + │ ├── wmt17_translate-train.tfrecord-00013-of-00016 + │ ├── wmt17_translate-train.tfrecord-00014-of-00016 + │ ├── wmt17_translate-train.tfrecord-00015-of-00016 + │ └── wmt17_translate-validation.tfrecord-00000-of-00001 + └── wmt_sentencepiece_model +``` + +In total, it should contain 43 files (via `find -type f | wc -l`) for a total of 3.3 GB (via `du -sch wmt/`). +
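To spot-check that the prepared TFDS records are actually readable, a sketch like the following can be used (it assumes `tensorflow_datasets` is installed and that `--data_dir` was `~/data`; the version string matches the `1.0.0` folders above):

```python
import os

import tensorflow_datasets as tfds

# Read a single translation pair from the prepared train split.
# download=False ensures nothing is fetched if files are missing.
ds = tfds.load(
    'wmt17_translate/de-en:1.0.0',
    split='train',
    data_dir=os.path.expanduser('~/data/wmt'),
    download=False)
for example in ds.take(1):
  print(example['de'].numpy().decode()[:80])
  print(example['en'].numpy().decode()[:80])
```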
+
 ### FastMRI
 
 Fill out the form on <https://fastmri.med.nyu.edu/>. After filling out the form
 you should get an email containing the URLs for "knee_singlecoil_train",
 "knee_singlecoil_val" and "knee_singlecoil_test".
 
@@ -118,6 +202,29 @@ python3 datasets/dataset_setup.py \
 --fastmri_knee_singlecoil_test_url ''
 ```
 
+
+The final directory structure should look like this: + +```bash +$DATA_DIR +├── fastmri +│ ├── knee_singlecoil_test +│ │ ├── file1000022.h5 +│ │ ├── [...] +│ │ └── file1002571.h5 +│ ├── knee_singlecoil_train +│ │ ├── file1000001.h5 +│ │ ├── [...] +│ │ └── file1002569.h5 +│ └── knee_singlecoil_val +│ ├── file1000000.h5 +│ ├── [...] +│ └── file1002570.h5 +``` + +In total, it should contain 1280 files (via `find -type f | wc -l`) for a total of 112 GB (via `du -sch fastmri/`). +
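Each `.h5` file can be opened with `h5py` to verify it is intact. A minimal sketch (assuming `h5py` is installed; the file name is taken from the tree above):

```python
import os

import h5py  # pip install h5py

path = os.path.expanduser('~/data/fastmri/knee_singlecoil_train/file1000001.h5')
with h5py.File(path, 'r') as f:
  print(list(f.keys()))  # E.g. ['ismrmrd_header', 'kspace', 'reconstruction_esc'].
  print(f['kspace'].shape)  # Complex k-space volume, (num_slices, height, width).
```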
+
 ### ImageNet
 
 Register on <https://image-net.org/> and follow directions to obtain the
 URLs for the ILSVRC2012 train and validation images.
 
@@ -149,6 +256,73 @@ dangerous if the user injects code into the `--data_dir` or `--temp_dir` flags. We
 do some basic sanitization in `main()`, but submitters should not let untrusted
 users run this script on their systems.
 
+
+The final directory structure should look like this for ImageNet2012 (PyTorch): + +```bash +$DATA_DIR +├── imagenet +│ ├── train +│ ├── n01440764 +│ ├── n01440764_10026.JPEG +│ ├── n01440764_10027.JPEG +│ ├── n01440764_10029.JPEG +│ ├── [...] +│ ├── [...] +│ └── val +│ ├── n01440764 +│ ├── ILSVRC2012_val_00000293.JPEG +│ ├── ILSVRC2012_val_00002138.JPEG +│ ├── [...] +│ ├── [...] +``` + +In total, it should contain 1,281,167 `train` files and 50,000 `val` (via `find -type f | wc -l`) for a total of 177 GB and 7.8 GB, respectively (via `du -sch train/` and `du -sch val/`). +
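For the PyTorch layout, `torchvision`'s `ImageFolder` offers a quick sanity check of the class-folder structure (a verification sketch only; the benchmark defines its own input pipeline):

```python
import os

from torchvision import datasets

data_dir = os.path.expanduser('~/data/imagenet')
train_set = datasets.ImageFolder(os.path.join(data_dir, 'train'))
val_set = datasets.ImageFolder(os.path.join(data_dir, 'val'))
print(len(train_set), len(train_set.classes))  # Expect 1281167 and 1000.
print(len(val_set))  # Expect 50000.
```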
+ +**TODO** +
+The final directory structure should look like this for ImageNet2012 (JAX): + +```bash +$DATA_DIR +``` + +In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch imagenet/`). +
+ +
+The final directory structure should look like this for ImageNet v2: + +```bash +$DATA_DIR +├── imagenet_v2 +│ └── matched-frequency +│ └── 3.0.0 +│ ├── dataset_info.json +│ ├── features.json +│ ├── imagenet_v2-test.tfrecord-00000-of-00016 +│ ├── imagenet_v2-test.tfrecord-00001-of-00016 +│ ├── imagenet_v2-test.tfrecord-00002-of-00016 +│ ├── imagenet_v2-test.tfrecord-00003-of-00016 +│ ├── imagenet_v2-test.tfrecord-00004-of-00016 +│ ├── imagenet_v2-test.tfrecord-00005-of-00016 +│ ├── imagenet_v2-test.tfrecord-00006-of-00016 +│ ├── imagenet_v2-test.tfrecord-00007-of-00016 +│ ├── imagenet_v2-test.tfrecord-00008-of-00016 +│ ├── imagenet_v2-test.tfrecord-00009-of-00016 +│ ├── imagenet_v2-test.tfrecord-00010-of-00016 +│ ├── imagenet_v2-test.tfrecord-00011-of-00016 +│ ├── imagenet_v2-test.tfrecord-00012-of-00016 +│ ├── imagenet_v2-test.tfrecord-00013-of-00016 +│ ├── imagenet_v2-test.tfrecord-00014-of-00016 +│ ├── imagenet_v2-test.tfrecord-00015-of-00016 +│ └── label.labels.txt +``` + +In total, it should contain 20 files (via `find -type f | wc -l`) for a total of 1.2 GB (via `du -sch imagenet_v2/`). +
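The automatically downloaded test set can be verified through the TFDS builder metadata (a sketch assuming the `imagenet_v2` folder sits directly under `$DATA_DIR` as shown above):

```python
import os

import tensorflow_datasets as tfds

builder = tfds.builder(
    'imagenet_v2/matched-frequency:3.0.0',
    data_dir=os.path.expanduser('~/data'))
print(builder.info.splits['test'].num_examples)  # Expect 10000.
```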
+ ### Criteo1TB ```bash @@ -158,6 +332,17 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` +**TODO** +
+The final directory structure should look like this: + +```bash +$DATA_DIR +``` + +In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch criteo1tb/`). +
+ ### LibriSpeech To download, train a tokenizer and preprocess the librispeech dataset: @@ -169,6 +354,17 @@ python3 datasets/dataset_setup.py \ --librispeech ``` +**TODO** +
+The final directory structure should look like this: + +```bash +$DATA_DIR +``` + +In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch librispeech/`). +
+ #### Training SPM Tokenizer A simple sentence piece tokenizer is trained over librispeech training From 8453784de7e8bc4a24669276befc75b562a6e0a2 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Fri, 15 Dec 2023 13:02:49 +0100 Subject: [PATCH 009/155] Highlight next deadline --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 3e888a0a9..941344903 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ > *AlgoPerf* is a suite of benchmarks and competitions to measure neural network training speedups due to algorithmic improvements in both training algorithms and models. This is the repository for the *AlgoPerf: Training Algorithms benchmark* and its associated competition. It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/). This repository holds the [**competition rules**](/COMPETITION_RULES.md), the [**technical documentation**](/DOCUMENTATION.md) of the benchmark, [**getting started guides**](/GETTING_STARTED.md), and the benchmark code. For a detailed description of the benchmark design, see our [**paper**](https://arxiv.org/abs/2306.07179). +--- + +> [!IMPORTANT] +> Upcoming Deadline: +> Registration deadline to express non-binding intent to submit: **January 28th, 2024** + ## Table of Contents - [Installation](#installation) From 1fdd724cb91e9355336452266ffd3e9619b1d840 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 11:30:09 +0100 Subject: [PATCH 010/155] Add missing directory structures --- datasets/README.md | 68 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 7 deletions(-) diff --git a/datasets/README.md b/datasets/README.md index 4f7b6b880..ce2a6390e 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -280,19 +280,38 @@ $DATA_DIR In total, it should contain 1,281,167 `train` files and 50,000 `val` (via `find -type f | wc -l`) for a total of 177 GB and 7.8 GB, respectively (via `du -sch train/` and `du -sch val/`). -**TODO**
The final directory structure should look like this for ImageNet2012 (JAX): ```bash $DATA_DIR +├──imagenet +│ ├── jax +│ │ ├── downloads +│ │ │ ├── extracted +│ │ │ └── manual_ +│ │ ├── imagenet2012 +│ │ │ └── 5.1.0 +│ │ │ ├── dataset_info.json +│ │ │ ├── features.json +│ │ │ ├── imagenet2012-train.tfrecord-00000-of-01024 +│ │ │ ├── imagenet2012-train.tfrecord-00001-of-01024 +│ │ │ ├── [...] +│ │ └── imagenet_v2 +│ │ └── matched-frequency +│ │ └── 3.0.0 +│ │ ├── dataset_info.json +│ │ ├── features.json +│ │ ├── imagenet_v2-test.tfrecord-00000-of-00016 +│ │ ├── imagenet_v2-test.tfrecord-00001-of-00016 +│ │ ├── [...] ``` -In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch imagenet/`). +In total, it should contain 1,111 files (via `find -type f | wc -l`) for a total of 145 GB (via `du -sch imagenet/jax`).
-The final directory structure should look like this for ImageNet v2: +The final directory structure should look like this for ImageNet v2 (separate): ```bash $DATA_DIR @@ -332,15 +351,20 @@ python3 datasets/dataset_setup.py \ --criteo1tb ``` -**TODO**
The final directory structure should look like this: ```bash $DATA_DIR +├── criteo1tb +│ ├── day_0_000.csv +│ ├── day_0_001.csv +│ ├── day_0_002.csv +│ ├── day_0_003.csv +│ ├── [...] ``` -In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch criteo1tb/`). +In total, it should contain 885 files (via `find -type f | wc -l`) for a total of 1.1 TB (via `du -sch criteo1tb/`).
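Each `day_*` shard is plain text in the standard Criteo click-logs format: tab-separated records with a click label, 13 integer features, and 26 hashed categorical features, where missing values are empty fields. A small sketch to peek at one record (illustrative; the shard name is taken from the tree above):

```python
import os

path = os.path.expanduser('~/data/criteo1tb/day_0_000.csv')
with open(path, 'r') as f:
  fields = f.readline().rstrip('\n').split('\t')
print(len(fields))  # Expect 40 fields: 1 label + 13 numeric + 26 categorical.
print(fields[0])  # The click label, '0' or '1'.
```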
### LibriSpeech @@ -354,15 +378,45 @@ python3 datasets/dataset_setup.py \ --librispeech ``` -**TODO**
The final directory structure should look like this: ```bash $DATA_DIR +├──librispeech +│ ├── dev-clean +│ │ ├── 1272-128104-0000_audio.npy +│ │ ├── 1272-128104-0000_targets.npy +│ │ ├── [...] +│ ├── dev-clean.csv +│ ├── dev-other +│ │ ├── 116-288045-0000_audio.npy +│ │ ├── 116-288045-0000_targets.npy +│ │ ├── [...] +│ ├── dev-other.csv +│ ├── spm_model.vocab +│ ├── test-clean +│ │ ├── 1089-134686-0000_audio.npy +│ │ ├── 1089-134686-0000_targets.npy +│ │ ├── [...] +│ ├── test-clean.csv +│ ├── train-clean-100 +│ │ ├── 103-1240-0000_audio.npy +│ │ ├── 103-1240-0000_targets.npy +│ │ ├── [...] +│ ├── train-clean-100.csv +│ ├── train-clean-360 +│ │ ├── 100-121669-0000_audio.npy +│ │ ├── 100-121669-0000_targets.npy +│ │ ├── [...] +│ ├── train-clean-360.csv +│ │ ├── 985-126228-0050_audio.npy +│ │ └── 985-126228-0050_targets.npy +│ │ ├── [...] +│ └── train-other-500.csv ``` -In total, it should contain ?? files (via `find -type f | wc -l`) for a total of ?? GB (via `du -sch librispeech/`). +In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`).
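Each utterance is stored as a pair of NumPy arrays that can be inspected directly (a sketch; the utterance id is one from the tree above, and the exact shapes depend on the preprocessing):

```python
import os

import numpy as np

split_dir = os.path.expanduser('~/data/librispeech/dev-clean')
audio = np.load(os.path.join(split_dir, '1272-128104-0000_audio.npy'))
targets = np.load(os.path.join(split_dir, '1272-128104-0000_targets.npy'))
print(audio.shape, audio.dtype)  # Raw audio samples.
print(targets.shape, targets.dtype)  # Tokenized transcript ids.
```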
#### Training SPM Tokenizer From b3b0785458ddb0ed38c25945a744530930637f06 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 11:30:21 +0100 Subject: [PATCH 011/155] Add download and disk sizes --- datasets/dataset_setup.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f765e4a1a..9140ed18a 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -26,17 +26,19 @@ Criteo 1TB download size: ~350GB Criteo 1TB final disk size: ~1TB -FastMRI download size: -FastMRI final disk size: -LibriSpeech download size: -LibriSpeech final disk size: -OGBG download size: -OGBG final disk size: -WMT download size: (1.58 GiB + ) = -WMT final disk size: +FastMRI download size: ~90GB +FastMRI final disk size: ~110GB +ImageNet download size: ~150GB +ImageNet final disk size: ~150GB +LibriSpeech download size: ~60GB +LibriSpeech final disk size: ~350GB +OGBG download size: ~37MB +OGBG final disk size: ~800MB +WMT download size: ~3GB +WMT final disk size: ~3GB _______________________ -Total download size: -Total disk size: +Total download size: ~650GB +Total disk size: ~1.1TB Some datasets require signing a form before downloading: @@ -49,8 +51,8 @@ Register on https://image-net.org/ and run this script with the links to the ILSVRC2012 train and validation images. -Note for tfds ImageNet, you may have to increase the max number of files allowed -open at once using `ulimit -n 8192`. +Note for tfds ImageNet, you may have to increase the max number of files +allowed open at once using `ulimit -n 8192`. Example command: From e1fd0f1e22c72f8d099cb52f3647ce04a9606f89 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 11:30:35 +0100 Subject: [PATCH 012/155] Remove unused import --- datasets/dataset_setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 9140ed18a..2ddbf4438 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -79,7 +79,6 @@ import functools import os -import resource import shutil import subprocess import tarfile From efa518572531ecbe958158794b256063f6557939 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Tue, 19 Dec 2023 12:46:32 +0100 Subject: [PATCH 013/155] fix fastmri dir structure and simplify --- datasets/dataset_setup.py | 46 +++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2ddbf4438..7638cd1b6 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -399,31 +399,35 @@ def extract(source, dest, mode='r:xz'): tar.close() -def setup_fastmri(data_dir, src_data_dir): - data_dir = os.path.join(data_dir, 'fastmri') - - train_tar_file_path = os.path.join(src_data_dir, FASTMRI_TRAIN_TAR_FILENAME) - val_tar_file_path = os.path.join(src_data_dir, FASTMRI_VAL_TAR_FILENAME) - test_tar_file_path = os.path.join(src_data_dir, FASTMRI_TEST_TAR_FILENAME) - - # Make train, val and test subdirectories - train_data_dir = os.path.join(data_dir, 'train') - os.makedirs(train_data_dir, exist_ok=True) - val_data_dir = os.path.join(data_dir, 'val') - os.makedirs(val_data_dir, exist_ok=True) - test_data_dir = os.path.join(data_dir, 'test') - os.makedirs(test_data_dir, exist_ok=True) +def setup_fastmri(data_dir): + train_tar_file_path = os.path.join(data_dir, FASTMRI_TRAIN_TAR_FILENAME) + val_tar_file_path = os.path.join(data_dir, 
FASTMRI_VAL_TAR_FILENAME)
+  test_tar_file_path = os.path.join(data_dir, FASTMRI_TEST_TAR_FILENAME)
 
   # Unzip tar file into subdirectories
-  logging.info('Unzipping {} to {}'.format(train_tar_file_path, train_data_dir))
-  extract(train_tar_file_path, train_data_dir)
-  logging.info('Unzipping {} to {}'.format(val_tar_file_path, val_data_dir))
-  extract(val_tar_file_path, val_data_dir)
-  logging.info('Unzipping {} to {}'.format(test_tar_file_path, test_data_dir))
-  extract(test_tar_file_path, test_data_dir)
-  logging.info('Set up fastMRI dataset complete')
+  logging.info('Unzipping {} to {}'.format(train_tar_file_path, data_dir))
+  extract(train_tar_file_path, data_dir)
+  logging.info('Unzipping {} to {}'.format(val_tar_file_path, data_dir))
+  extract(val_tar_file_path, data_dir)
+  logging.info('Unzipping {} to {}'.format(test_tar_file_path, data_dir))
+  extract(test_tar_file_path, data_dir)
   logging.info('Extraction completed!')
 
+  # Rename folders to match what the workload expects
+  os.rename(
+      os.path.join(data_dir, 'singlecoil_train'),
+      os.path.join(data_dir, 'knee_singlecoil_train'),
+  )
+  os.rename(
+      os.path.join(data_dir, 'singlecoil_val'),
+      os.path.join(data_dir, 'knee_singlecoil_val'),
+  )
+  os.rename(
+      os.path.join(data_dir, 'singlecoil_test'),
+      os.path.join(data_dir, 'knee_singlecoil_test'),
+  )
+  logging.info('Set up fastMRI dataset complete')
+
 
 def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
   """Downloads and returns the download dir."""

From c4d473393f08ff5cbd878989f7cbdb3537af5352 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Tue, 19 Dec 2023 12:50:47 +0100
Subject: [PATCH 014/155] Move pydub to librispeech dependency

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index a00da91fc..20139d4c0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -67,7 +67,6 @@ full =
   %(ogbg)s
   %(librispeech_conformer)s
   %(wmt)s
-  pydub==0.25.1
 
 # All workloads plus development dependencies
 full_dev =
@@ -98,6 +97,7 @@ ogbg =
 librispeech_conformer =
   sentencepiece==0.1.99
   tensorflow-text==2.12.1
+  pydub==0.25.1
 
 wmt =
   sentencepiece==0.1.99

From 51e83409237168910661d875ea97b7812482ce79 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Tue, 19 Dec 2023 12:50:58 +0100
Subject: [PATCH 015/155] Note the requirement of pigz and ffmpeg

---
 datasets/README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/datasets/README.md b/datasets/README.md
index ce2a6390e..685f56eed 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -351,6 +351,8 @@ python3 datasets/dataset_setup.py \
 --criteo1tb
 ```
 
+Note that this requires the [`pigz` library](https://zlib.net/pigz/) to be installed.
+
 The final directory structure should look like this:

@@ -378,6 +380,8 @@ python3 datasets/dataset_setup.py \
 --librispeech
 ```
 
+Note that this requires the [`ffmpeg` toolbox](https://ffmpeg.org/) to be installed.
+
 The final directory structure should look like this:

From aad8ec4b795e41f40216c0aca1ea7cb4c0815eb8 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 20 Dec 2023 21:33:57 +0000
Subject: [PATCH 016/155] fix criteo datasetting split

---
 datasets/dataset_setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index f9ee2f138..755b4a93e 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -324,7 +324,7 @@ def download_criteo1tb(data_dir,
     unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
     unzipped_paths.append(unzipped_path)
     split_path = os.path.join(criteo_dir, f'day_{day}_')
-    split_cmd = ('split -a 3 -d -l 5000000 --additional-suffix=.csv '
+    split_cmd = ('split -a 2 -d -l 5000000 '
                  f'"{unzipped_path}" "{split_path}"')
     logging.info(f'Running Criteo 1TB split command:\n{split_cmd}')
     batch_processes.append(subprocess.Popen(split_cmd, shell=True))

From 1d6330c792fec0501bc62c24a23917cb553a78d1 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Thu, 21 Dec 2023 11:36:53 +0100
Subject: [PATCH 017/155] Specified the librispeech structure

---
 datasets/README.md | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/datasets/README.md b/datasets/README.md
index 685f56eed..5f2ce7504 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -388,36 +388,39 @@ Note that this requires the [`ffmpeg` toolbox](https://ffmpeg.org/) to be installed.
 ```bash
 $DATA_DIR
 ├──librispeech
+│ ├── dev-clean.csv
+│ ├── dev-other.csv
+│ ├── spm_model.vocab
+│ ├── test-clean.csv
+│ ├── train-clean-100.csv
+│ ├── train-clean-360.csv
+│ ├── train-other-500.csv
 │ ├── dev-clean
 │ │ ├── 1272-128104-0000_audio.npy
 │ │ ├── 1272-128104-0000_targets.npy
+│ │ ├── 1272-128104-0001_audio.npy
+│ │ ├── 1272-128104-0001_targets.npy
 │ │ ├── [...]
-│ ├── dev-clean.csv
 │ ├── dev-other
 │ │ ├── 116-288045-0000_audio.npy
 │ │ ├── 116-288045-0000_targets.npy
 │ │ ├── [...]
-│ ├── dev-other.csv
-│ ├── spm_model.vocab
 │ ├── test-clean
 │ │ ├── 1089-134686-0000_audio.npy
 │ │ ├── 1089-134686-0000_targets.npy
 │ │ ├── [...]
-│ ├── test-clean.csv
 │ ├── train-clean-100
 │ │ ├── 103-1240-0000_audio.npy
 │ │ ├── 103-1240-0000_targets.npy
 │ │ ├── [...]
-│ ├── train-clean-100.csv
 │ ├── train-clean-360
 │ │ ├── 100-121669-0000_audio.npy
 │ │ ├── 100-121669-0000_targets.npy
 │ │ ├── [...]
-│ ├── train-clean-360.csv
-│ │ ├── 985-126228-0050_audio.npy
-│ │ └── 985-126228-0050_targets.npy
+│ ├── train-other-500
+│ │ ├── 1006-135212-0000_audio.npy
+│ │ ├── 1006-135212-0000_targets.npy
 │ │ ├── [...]
-│ └── train-other-500.csv
 ```
 
 In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`).
From 0d911e0aeaabe558fb0df5d9eb2cf47b311d62a0 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 11:41:17 +0100 Subject: [PATCH 018/155] Do not process `test-other` split --- datasets/dataset_setup.py | 2 ++ datasets/librispeech_preprocess.py | 6 ++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 7638cd1b6..f52a9808e 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -601,6 +601,8 @@ def download_librispeech(dataset_dir, tmp_dir): for split in ['dev', 'test']: for version in ['clean', 'other']: + if split == 'test' and version == 'other': + continue wget_cmd = ( f'wget --directory-prefix={tmp_librispeech_dir} ' f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz') diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py index acdaa8e98..a8c5cae1d 100644 --- a/datasets/librispeech_preprocess.py +++ b/datasets/librispeech_preprocess.py @@ -31,8 +31,7 @@ 'train-clean-100': 28539, 'train-clean-360': 104014, 'train-other-500': 148688, - 'test-clean': 2620, - 'test-other': 2939, + 'test-clean': 2620, # 'test-other': 2939, 'dev-clean': 2703, 'dev-other': 2864, } @@ -153,8 +152,7 @@ def run(input_dir, output_dir, tokenizer_vocab_path): 'train-other-500', 'dev-clean', 'dev-other', - 'test-clean', - 'test-other', + 'test-clean', # 'test-other', ] for subset in subset_list: logging.info('Processing split = %s...', subset) From 6cf89e4e598f156da46c3cb401119582e8cb7d13 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 21 Dec 2023 11:55:38 +0100 Subject: [PATCH 019/155] Store tokenizer in the right directory --- datasets/dataset_setup.py | 11 +++++++---- datasets/librispeech_tokenizer.py | 11 +++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index f52a9808e..ab9f31db5 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -588,13 +588,13 @@ def download_imagenet_v2(data_dir): data_dir=data_dir).download_and_prepare() -def download_librispeech(dataset_dir, tmp_dir): +def download_librispeech(data_dir, tmp_dir): # After extraction the result is a folder named Librispeech containing audio # files in .flac format along with transcripts containing name of audio file # and corresponding transcription. 
  tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech')
   extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
-  final_data_dir = os.path.join(dataset_dir, 'librispeech')
+  final_data_dir = os.path.join(data_dir, 'librispeech')
 
   _maybe_mkdir(tmp_librispeech_dir)
   _maybe_mkdir(final_data_dir)
@@ -627,10 +627,13 @@ def download_librispeech(data_dir, tmp_dir):
         f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
         shell=True).communicate()
 
-  tokenizer_vocab_path = os.path.join(extracted_data_dir, 'spm_model.vocab')
+  tokenizer_vocab_path = os.path.join(data_dir, 'spm_model.vocab')
 
   if not os.path.exists(tokenizer_vocab_path):
-    librispeech_tokenizer.run(train=True, data_dir=extracted_data_dir)
+    librispeech_tokenizer.run(
+        train=True,
+        input_dir=extracted_data_dir,
+        tokenizer_vocab_path=tokenizer_vocab_path)
 
   librispeech_preprocess.run(
       input_dir=extracted_data_dir,
diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py
index e701d59d4..2f559752a 100644
--- a/datasets/librispeech_tokenizer.py
+++ b/datasets/librispeech_tokenizer.py
@@ -108,17 +108,16 @@ def load_tokenizer(model_filepath):
   return sp_tokenizer
 
 
-def run(train, data_dir):
-  logging.info('Data dir: %s', data_dir)
-  vocab_path = os.path.join(data_dir, 'spm_model.vocab')
-  logging.info('vocab_path = ', vocab_path)
+def run(train, input_dir, tokenizer_vocab_path):
+  logging.info('Data dir: %s', input_dir)
+  logging.info('vocab_path = %s', tokenizer_vocab_path)
 
   if train:
     logging.info('Training...')
     splits = ['train-clean-100']
-    train_tokenizer(data_dir, splits, model_path=vocab_path)
+    train_tokenizer(input_dir, splits, model_path=tokenizer_vocab_path)
   else:
-    tokenizer = load_tokenizer(vocab_path)
+    tokenizer = load_tokenizer(tokenizer_vocab_path)
     test_input = 'OPEN SOURCE ROCKS'
     tokens = tokenizer.tokenize(test_input)
     detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8')

From 25739febbabba6e203243bb3678fd0d870aba4ab Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Thu, 21 Dec 2023 12:24:32 +0100
Subject: [PATCH 020/155] fix tokenizer folder

---
 datasets/dataset_setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index ab9f31db5..ad373dd43 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -627,7 +627,7 @@ def download_librispeech(data_dir, tmp_dir):
         f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
         shell=True).communicate()
 
-  tokenizer_vocab_path = os.path.join(data_dir, 'spm_model.vocab')
+  tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab')
 
   if not os.path.exists(tokenizer_vocab_path):
     librispeech_tokenizer.run(
         train=True,
         input_dir=extracted_data_dir,
         tokenizer_vocab_path=tokenizer_vocab_path)

From 87e26720da18980227d30f21f280dd4c8e1de357 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Thu, 21 Dec 2023 12:56:35 +0100
Subject: [PATCH 021/155] Fix the final directory structure of Criteo

---
 datasets/README.md | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/datasets/README.md b/datasets/README.md
index 5f2ce7504..37480f4f8 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -359,10 +359,10 @@ Note that this requires the [`pigz` library](https://zlib.net/pigz/) to be installed.
 ```bash
 $DATA_DIR
 ├── criteo1tb
-│ ├── day_0_000.csv
-│ ├── day_0_001.csv
-│ ├── day_0_002.csv
-│ ├── day_0_003.csv
+│ ├── day_0_00
+│ ├── day_0_01
+│ ├── day_0_02
+│ ├── day_0_03
 │ ├── [...]
 ```
 
@@ -428,8 +428,8 @@ In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`).
 
 #### Training SPM Tokenizer
 
- A simple sentence piece tokenizer is trained over librispeech training
- data. This tokenizer is then used in later preprocessing step to tokenize transcripts.
+During the above commands, a simple sentence piece tokenizer is trained over librispeech training data.
+This tokenizer is then used in a later preprocessing step to tokenize transcripts.
 This command generates `spm_model.vocab` file in `$DATA_DIR/librispeech`:
 
 ```bash
 python3 librispeech_tokenizer.py --train --data_dir=$DATA_DIR/librispeech
 ```

From 7fb8124b4486274d81d500d6e2f64905034f1e2f Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 2 Jan 2024 12:55:01 +0000
Subject: [PATCH 022/155] add model and workload variant code for ogbg

---
 .../workloads/ogbg/ogbg_jax/models.py         | 24 +++++++++---
 .../workloads/ogbg/ogbg_jax/workload.py       | 38 ++++++++++++++++++-
 .../workloads/ogbg/ogbg_pytorch/models.py     | 26 +++++++++----
 .../workloads/ogbg/ogbg_pytorch/workload.py   | 36 +++++++++++++++++-
 .../workloads/ogbg/workload.py                | 19 +++++++++-
 5 files changed, 128 insertions(+), 15 deletions(-)

diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py
index 358415587..0e66d2ab8 100644
--- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py
+++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py
@@ -15,7 +15,7 @@ def make_fn(inputs):
   return make_fn
 
 
-def _make_mlp(hidden_dims, dropout):
+def _make_mlp(hidden_dims, dropout, activation_fn):
   """Creates a MLP with specified dimensions."""
 
   @jraph.concatenated_args
@@ -24,7 +24,7 @@ def make_fn(inputs):
     x = inputs
     for dim in hidden_dims:
       x = nn.Dense(features=dim)(x)
       x = nn.LayerNorm()(x)
-      x = nn.relu(x)
+      x = activation_fn(x)
       x = dropout(x)
     return x
 
@@ -42,6 +42,7 @@ class GNN(nn.Module):
   # If None, defaults to 0.1.
dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 + activation_fn_name: str = 'relu' @nn.compact def __call__(self, graph, train): @@ -59,11 +60,24 @@ def __call__(self, graph, train): embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding')) graph = embedder(graph) + if self.activation_fn_name == 'relu': + activation_fn = nn.relu + elif self.activation_fn_name == 'gelu': + activation_fn = nn.gelu + elif self.activation_fn_name == 'silu': + activation_fn = nn.silu + else: + raise ValueError( + f'Invalid activation function name: {self.activation_fn_name}') + for _ in range(self.num_message_passing_steps): net = jraph.GraphNetwork( - update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout), - update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout), - update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout)) + update_edge_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_node_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn), + update_global_fn=_make_mlp( + self.hidden_dims, dropout=dropout, activation_fn=activation_fn)) graph = net(graph) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 009aab91a..809148631 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -25,7 +25,13 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate rng, params_rng, dropout_rng = jax.random.split(rng, 3) - self._model = models.GNN(self._num_outputs, dropout_rate=dropout_rate) + self._model = models.GNN( + self._num_outputs, + dropout_rate=dropout_rate, + activation_fn_name=self.activation_fn_name, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps) init_fn = jax.jit(functools.partial(self._model.init, train=False)) fake_batch = jraph.GraphsTuple( n_node=jnp.asarray([1]), @@ -115,3 +121,33 @@ def _normalize_eval_metrics( del num_examples total_metrics = total_metrics.reduce() return {k: float(v) for k, v in total_metrics.compute().items()} + + +class OgbgGeluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" + return 'gelu' + + +class OgbgSiluWorkload(OgbgWorkload): + + @property + def activation_fn_name(self) -> str: + """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'."""
+    return 'silu'
+
+
+class OgbgModelSizeWorkload(OgbgWorkload):
+
+  @property
+  def hidden_dims(self) -> Tuple[int]:
+    return (256, 256)
+
+  @property
+  def latent_dim(self) -> int:
+    return 128
+
+  @property
+  def num_message_passing_steps(self) -> int:
+    return 5
\ No newline at end of file
diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py
index 1b392753b..04c503179 100644
--- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py
@@ -10,7 +10,7 @@
 from algorithmic_efficiency import init_utils
 
 
-def _make_mlp(in_dim, hidden_dims, dropout_rate):
+def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn):
   """Creates a MLP with specified dimensions."""
   layers = nn.Sequential()
   for dim in hidden_dims:
@@ -33,7 +33,8 @@ class GNN(nn.Module):
 
   def __init__(self,
                num_outputs: int = 128,
-               dropout_rate: Optional[float] = 0.1) -> None:
+               dropout_rate: Optional[float] = 0.1,
+               activation_fn_name: str = 'relu') -> None:
     super().__init__()
     self.num_outputs = num_outputs
     if dropout_rate is None:
@@ -42,6 +43,16 @@ def __init__(self,
     self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim)
     self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim)
 
+    if activation_fn_name == 'relu':
+      activation_fn = nn.ReLU
+    elif activation_fn_name == 'gelu':
+      activation_fn = nn.GELU
+    elif activation_fn_name == 'silu':
+      activation_fn = nn.SiLU
+    else:
+      raise ValueError(
+          f'Invalid activation function name: {activation_fn_name}')
+
     graph_network_layers = []
     for st in range(self.num_message_passing_steps):
       # Constants in in_dims are based on the requirements of the GraphNetwork.
@@ -54,11 +65,12 @@ def __init__(self,
 
       graph_network_layers.append(
           GraphNetwork(
-              update_edge_fn=_make_mlp(in_dim, self.hidden_dims, dropout_rate),
-              update_node_fn=_make_mlp(in_dim, self.hidden_dims, dropout_rate),
-              update_global_fn=_make_mlp(last_in_dim,
-                                         self.hidden_dims,
-                                         dropout_rate)))
+              update_edge_fn=_make_mlp(
+                  in_dim, self.hidden_dims, dropout_rate, activation_fn),
+              update_node_fn=_make_mlp(
+                  in_dim, self.hidden_dims, dropout_rate, activation_fn),
+              update_global_fn=_make_mlp(
+                  last_in_dim, self.hidden_dims, dropout_rate, activation_fn)))
     self.graph_network = nn.Sequential(*graph_network_layers)
 
     self.decoder = nn.Linear(
diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
index a1fbf2e8a..b2224bdec 100644
--- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py
@@ -144,7 +144,12 @@
     """aux_dropout_rate is unused."""
     del aux_dropout_rate
     torch.random.manual_seed(rng[0])
-    model = GNN(num_outputs=self._num_outputs, dropout_rate=dropout_rate)
+    model = GNN(num_outputs=self._num_outputs,
+                dropout_rate=dropout_rate,
+                activation_fn_name=self.activation_fn_name,
+                hidden_dims=self.hidden_dims,
+                latent_dim=self.latent_dim,
+                num_message_passing_steps=self.num_message_passing_steps)
     self._param_shapes = param_utils.pytorch_param_shapes(model)
     self._param_types = param_utils.pytorch_param_types(self._param_shapes)
     model.to(DEVICE)
@@ -235,3 +239,33 @@ def _normalize_eval_metrics(
     """Normalize eval metrics."""
     del num_examples
     return {k: float(v) for k, v in total_metrics.compute().items()}
+
+
+class OgbgGeluWorkload(OgbgWorkload):
+
+  @property
+  def activation_fn_name(self) -> str:
+    """Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
+    return 'gelu'
+
+
+class OgbgSiluWorkload(OgbgWorkload):
+
+  @property
+  def activation_fn_name(self) -> str:
+    """Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
+    return 'silu'
+
+
+class OgbgModelSizeWorkload(OgbgWorkload):
+
+  @property
+  def hidden_dims(self) -> Tuple[int]:
+    return (256, 256)
+
+  @property
+  def latent_dim(self) -> int:
+    return 128
+
+  @property
+  def num_message_passing_steps(self) -> int:
+    return 5
diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py
index 7ca6ebc1e..8f3e8c122 100644
--- a/algorithmic_efficiency/workloads/ogbg/workload.py
+++ b/algorithmic_efficiency/workloads/ogbg/workload.py
@@ -3,7 +3,7 @@
 import abc
 import itertools
 import math
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import jax
 
@@ -22,6 +22,23 @@ def target_metric_name(self) -> str:
     """The name of the target metric (useful for scoring/processing code)."""
     return 'mean_average_precision'
 
+  @property
+  def activation_fn_name(self) -> str:
+    """Name of the activation function to use.
One of 'relu', 'gelu', 'silu'.""" + return 'relu' + + @property + def hidden_dims(self) -> Tuple[int]: + return (256,) + + @property + def latent_dim(self) -> int: + return 128 + + @property + def num_message_passing_steps(self) -> int: + return 5 + def has_reached_validation_target(self, eval_result: float) -> bool: return eval_result[ 'validation/mean_average_precision'] > self.validation_target_value From d6048500a701b27ddaf3d98b5a257f1b552e2326 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 13:02:40 +0000 Subject: [PATCH 023/155] add ogbg workload variant definitions to registry --- .../workloads/ogbg/ogbg_jax/workload.py | 3 ++- algorithmic_efficiency/workloads/workloads.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 809148631..e77194643 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -150,4 +150,5 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: - return 5 \ No newline at end of file + return 5 + \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6d0b08cef..09ddfabfd 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -96,6 +96,15 @@ 'ogbg': { 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload' }, + 'ogbg_gelu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload' + }, + 'ogbg_silu': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' + }, + 'ogbg_model_size': { + 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgModelSizeWorkload' + }, 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, 'wmt_post_ln': { 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN' From fdcb5aa098cbc1d0e1584c28946360506125d027 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 13:04:36 +0000 Subject: [PATCH 024/155] add ogbg variants to docker startup.sh --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index be14ab498..53ba3f6ba 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -119,7 +119,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ "conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ - "fastmri_layernorm") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") # Set data and experiment paths ROOT_DATA_BUCKET="gs://mlcommons-data" From 94a942050108af7c5dabb515c1b4ad5a65c4a8a2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 14:35:25 +0000 Subject: [PATCH 025/155] activation fn --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 04c503179..4a7d96c13 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -1,5 +1,6 @@ # Ported 
to PyTorch from
 # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
+from functools import partial
 from typing import Callable, Optional, Tuple
 
 import jax.tree_util as tree
@@ -46,7 +47,7 @@ def __init__(self,
     if activation_fn_name == 'relu':
       activation_fn = nn.ReLU
     elif activation_fn_name == 'gelu':
-      activation_fn = nn.GELU
+      activation_fn = partial(nn.GELU, approximate='tanh')
     elif activation_fn_name == 'silu':
       activation_fn = nn.SiLU
     else:
       raise ValueError(
           f'Invalid activation function name: {activation_fn_name}')

From 2c8a3e1cf8b897a6ee66638fd826139f523fa4f5 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Tue, 2 Jan 2024 14:43:36 +0000
Subject: [PATCH 026/155] add tests

---
 tests/modeldiffs/ogbg_gelu/__init__.py       |   0
 tests/modeldiffs/ogbg_gelu/compare.py        | 113 +++++++++++++++++++
 tests/modeldiffs/ogbg_model_size/__init__.py |   0
 tests/modeldiffs/ogbg_model_size/compare.py  | 113 +++++++++++++++++++
 tests/modeldiffs/ogbg_silu/__init__.py       |   0
 tests/modeldiffs/ogbg_silu/compare.py        | 113 +++++++++++++++++++
 6 files changed, 339 insertions(+)
 create mode 100644 tests/modeldiffs/ogbg_gelu/__init__.py
 create mode 100644 tests/modeldiffs/ogbg_gelu/compare.py
 create mode 100644 tests/modeldiffs/ogbg_model_size/__init__.py
 create mode 100644 tests/modeldiffs/ogbg_model_size/compare.py
 create mode 100644 tests/modeldiffs/ogbg_silu/__init__.py
 create mode 100644 tests/modeldiffs/ogbg_silu/compare.py

diff --git a/tests/modeldiffs/ogbg_gelu/__init__.py b/tests/modeldiffs/ogbg_gelu/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py
new file mode 100644
index 000000000..f6175e99d
--- /dev/null
+++ b/tests/modeldiffs/ogbg_gelu/compare.py
@@ -0,0 +1,113 @@
+import os
+
+# Disable GPU access for both jax and pytorch.
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgGeluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgGeluWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + bn = False + ln = False + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + if 'ModuleList' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'Linear' in i: + if 'NonDynamicallyQuantizableLinear' in i: + i = 'out' + else: + i = i.replace('Linear', 'Dense') + elif 'Conv1d' in i: + i = i.replace('Conv1d', 'Conv') + elif 'MHSAwithQS' in i: + i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + keys = list(sd.keys()) + out = {} + for k in keys: + new_key = k + if len(k) == 5: + _, gn_id, seq_id = k[:3] + gn_id = int(gn_id.split('_')[1]) + seq_id = int(seq_id.split('_')[1]) + if 'LayerNorm' in k[3]: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) + else: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) + elif len(k) == 2 and k[0] == 'Dense_2': + new_key = ('Dense_17', k[1]) + out[new_key] = sd[k] + + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/ogbg_model_size/__init__.py b/tests/modeldiffs/ogbg_model_size/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py new file mode 100644 index 000000000..3818598ed --- /dev/null +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -0,0 +1,113 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgModelSizeWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgModelSizeWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + bn = False + ln = False + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + if 'ModuleList' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'Linear' in i: + if 'NonDynamicallyQuantizableLinear' in i: + i = 'out' + else: + i = i.replace('Linear', 'Dense') + elif 'Conv1d' in i: + i = i.replace('Conv1d', 'Conv') + elif 'MHSAwithQS' in i: + i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + keys = list(sd.keys()) + out = {} + for k in keys: + new_key = k + if len(k) == 5: + _, gn_id, seq_id = k[:3] + gn_id = int(gn_id.split('_')[1]) + seq_id = int(seq_id.split('_')[1]) + if 'LayerNorm' in k[3]: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) + else: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) + elif len(k) == 2 and k[0] == 'Dense_2': + new_key = ('Dense_17', k[1]) + out[new_key] = sd[k] + + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/ogbg_silu/__init__.py b/tests/modeldiffs/ogbg_silu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py new file mode 100644 index 000000000..420ee9020 --- /dev/null +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -0,0 +1,113 @@ +import os + +# Disable GPU access for both jax and pytorch. 
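+# Same harness as the ogbg_gelu and ogbg_model_size comparisons; only the
+# imported workload classes differ.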
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jraph +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ + OgbgSiluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ + OgbgSiluWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + bn = False + ln = False + for i in k: + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + if 'ModuleList' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'Linear' in i: + if 'NonDynamicallyQuantizableLinear' in i: + i = 'out' + else: + i = i.replace('Linear', 'Dense') + elif 'Conv1d' in i: + i = i.replace('Conv1d', 'Conv') + elif 'MHSAwithQS' in i: + i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items + keys = list(sd.keys()) + out = {} + for k in keys: + new_key = k + if len(k) == 5: + _, gn_id, seq_id = k[:3] + gn_id = int(gn_id.split('_')[1]) + seq_id = int(seq_id.split('_')[1]) + if 'LayerNorm' in k[3]: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) + else: + new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) + elif len(k) == 2 and k[0] == 'Dense_2': + new_key = ('Dense_17', k[1]) + out[new_key] = sd[k] + + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = dict( + n_node=torch.LongTensor([5]), + n_edge=torch.LongTensor([5]), + nodes=torch.randn(5, 9), + edges=torch.randn(5, 3), + globals=torch.randn(1, 128), + senders=torch.LongTensor(list(range(5))), + receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) + + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. 
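+  # out_diff (see tests/modeldiffs/diff.py) transplants the transformed
+  # PyTorch state dict into the Flax model and then reports the difference
+  # between the two forward passes.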
+ graph_j = jraph.GraphsTuple(**jax_batch) + graph_p = jraph.GraphsTuple(**pyt_batch) + + jax_batch = {'inputs': graph_j} + pyt_batch = {'inputs': graph_p} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) From 609545f65246b99a290eb7622440082dcdfe0420 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:06:38 +0000 Subject: [PATCH 027/155] pytorch model ogbg fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 4a7d96c13..f616dac6e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -28,15 +28,18 @@ class GNN(nn.Module): The model assumes the input data is a jraph.GraphsTuple without global variables. The final prediction will be encoded in the globals. """ - latent_dim: int = 256 - hidden_dims: Tuple[int] = (256,) - num_message_passing_steps: int = 5 def __init__(self, num_outputs: int = 128, dropout_rate: Optional[float] = 0.1, - activation_fn_name: str = 'relu') -> None: + activation_fn_name: str = 'relu', + latent_dim: int = 256, + hidden_dims: Tuple[int] = (256,), + num_message_passing_steps: int = 5) -> None: super().__init__() + self.latent_dim = latent_dim + self.hidden_dims = hidden_dims + self.num_message_passing_steps = num_message_passing_steps self.num_outputs = num_outputs if dropout_rate is None: dropout_rate = 0.1 From 54d6ea6939c6cf0e13b87f096cd72c6229343196 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:20:53 +0000 Subject: [PATCH 028/155] ogbg fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index f616dac6e..0ae2c901a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -17,7 +17,7 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): for dim in hidden_dims: layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('relu', nn.ReLU()) + layers.add_module('activation_fn', activation_fn) layers.add_module('dropout', nn.Dropout(dropout_rate)) return layers From 20b790ebfaaffffef4a4ab4bb6348c07822c105d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:39:49 +0000 Subject: [PATCH 029/155] fix ogbg --- algorithmic_efficiency/workloads/ogbg/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index 8f3e8c122..ade91b35d 100644 --- 
a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -33,7 +33,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 128 + return 256 @property def num_message_passing_steps(self) -> int: From 362819ea5583abec690524e606f63712e633093d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:42:05 +0000 Subject: [PATCH 030/155] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 0ae2c901a..978b62428 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -17,7 +17,7 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): for dim in hidden_dims: layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('activation_fn', activation_fn) + layers.add_module('activation_fn', activation_fn()) layers.add_module('dropout', nn.Dropout(dropout_rate)) return layers From 4d63b9697e9b67f27604dcf6e8e0b11530f40a09 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 15:58:05 +0000 Subject: [PATCH 031/155] fix ogbg variant --- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 2 +- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index e77194643..65121ac7b 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -150,5 +150,5 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: - return 5 + return 3 \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index b2224bdec..c6e57b0f2 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -268,4 +268,4 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: - return 5 + return 3 From 58b0edf24829ac2c8e98a2139704b27b5402d01f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 16:14:10 +0000 Subject: [PATCH 032/155] ogbg debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index c6e57b0f2..102ef7b7c 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -264,7 +264,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 128 + return 256 @property def num_message_passing_steps(self) -> int: From e0d7dbfdf28895eb1115a6d02fd97f0833f3703a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 16:22:32 +0000 Subject: [PATCH 033/155] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 102ef7b7c..c6e57b0f2 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -264,7 +264,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 256 + return 128 @property def num_message_passing_steps(self) -> int: From 710db241e6b751164c4eede35cc8eee947734673 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 17:52:30 +0000 Subject: [PATCH 034/155] debugging --- .../workloads/ogbg/ogbg_jax/workload.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 65121ac7b..cb5dda800 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -43,6 +43,14 @@ def init_model_fn( receivers=jnp.asarray([0])) params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) params = params['params'] + tabulate_fn = nn.tabulate( + self._model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, 'force_jupyter': False, 'width': 240 + }, + ) + print(tabulate_fn(fake_batch, train=False)) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) return jax_utils.replicate(params), None From c2e1b8b6d959b7020b28cc99671fc948e48a21a0 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 17:53:44 +0000 Subject: [PATCH 035/155] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 1 + algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 1 + 2 files changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index cb5dda800..e4ee57fb7 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Tuple from flax import jax_utils +import flax.linen as nn import jax import jax.numpy as jnp import jraph diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index c6e57b0f2..aa0e7ae5e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -157,6 +157,7 @@ def init_model_fn( model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) + print(model) return model, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From 84e4606b9e2cf22c20b3a19233d306075de14e42 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 18:05:59 +0000 Subject: [PATCH 036/155] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index aa0e7ae5e..9c852b38d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -265,7 
+265,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 128 + return 256 @property def num_message_passing_steps(self) -> int: From 57f1d7cb75478fbb32e9bbc178ad3c473ec27f2c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 2 Jan 2024 18:16:58 +0000 Subject: [PATCH 037/155] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 9c852b38d..aa0e7ae5e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -265,7 +265,7 @@ def hidden_dims(self) -> Tuple[int]: @property def latent_dim(self) -> int: - return 256 + return 128 @property def num_message_passing_steps(self) -> int: From 503e4f07c37b9c0981481547d14a62c22a05aa2a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 17:56:35 +0000 Subject: [PATCH 038/155] debug --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 978b62428..6c104d59e 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -163,7 +163,6 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape [num_edges, global_feat]. global_edge_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_) - if self.update_edge_fn: edge_fn_inputs = torch.cat( [edges, sent_attributes, received_attributes, global_edge_attributes], @@ -180,6 +179,8 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape [num_nodes, global_feat]. 
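+    # (debug) log the shapes feeding node_fn_inputs to pin down the mismatch.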
global_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) + print('SHAPES') + print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From d6716df3bff341dd0c6c23cd60b4a0264e837613 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 17:59:38 +0000 Subject: [PATCH 039/155] debugging --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 6c104d59e..343fe9265 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -181,6 +181,8 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) print('SHAPES') print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) + print(senders.shape) + print(receivers.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From 997e0e9984789f155427e64ffa7ac0edd32b99ac Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 21:42:15 +0000 Subject: [PATCH 040/155] debugging --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 343fe9265..326ba3c06 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -183,6 +183,8 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) print(senders.shape) print(receivers.shape) + print(sum_n_node) + print(edges.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From f9cabc5e205ba208adb7d675d75b39c15449895d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:14:37 +0000 Subject: [PATCH 041/155] fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 326ba3c06..7ec2f142a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -59,20 +59,23 @@ def __init__(self, graph_network_layers = [] for st in range(self.num_message_passing_steps): - # Constants in in_dims are based on the requirements of the GraphNetwork. + # Constants in in_dims are based on forward call of GraphNetwork: + # specifically update_edge_fn update_node_fn and update_global_fn. 
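+      # Step 0 sizes follow from the cat() calls in GraphNetwork.forward():
+      # the edge MLP sees [edges, sent_node_feats, received_node_feats,
+      # globals] = 3 * latent_dim + num_outputs features, and the node MLP
+      # sees [nodes, sent_edge_sums, received_edge_sums, globals], where the
+      # two edge aggregates already have hidden_dims[-1] channels because
+      # the edge MLP runs first.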
if st == 0: - in_dim = self.latent_dim * 3 + self.num_outputs + in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outs last_in_dim = self.latent_dim * 2 + self.num_outputs else: - in_dim = self.hidden_dims[-1] * 4 + in_dim_edge_fn = self.hidden_dims[-1] * 4 + in_dim_node_fn = self.hidden_dims[-1] * 4 last_in_dim = self.hidden_dims[-1] * 3 graph_network_layers.append( GraphNetwork( update_edge_fn=_make_mlp( - in_dim, self.hidden_dims, dropout_rate, activation_fn), + in_dim_edge_fn, self.hidden_dims, dropout_rate, activation_fn), update_node_fn=_make_mlp( - in_dim, self.hidden_dims, dropout_rate, activation_fn), + in_dim_node_fn, self.hidden_dims, dropout_rate, activation_fn), update_global_fn=_make_mlp( last_in_dim, self.hidden_dims, dropout_rate, activation_fn))) self.graph_network = nn.Sequential(*graph_network_layers) From 176f0517dbe8e991f797541d69a3e27da38e4f9d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:15:55 +0000 Subject: [PATCH 042/155] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 7ec2f142a..5d8aab46d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -63,7 +63,7 @@ def __init__(self, # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outs + in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs last_in_dim = self.latent_dim * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 From 0e454ab2083a5c21081fe53756836f06f657bed4 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:45:20 +0000 Subject: [PATCH 043/155] debugging --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 5d8aab46d..9a3b4190c 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -64,7 +64,7 @@ def __init__(self, if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs - last_in_dim = self.latent_dim * 2 + self.num_outputs + last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 in_dim_node_fn = self.hidden_dims[-1] * 4 From ea17ae6b46b739788c1facd72eec448f73710d4c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 3 Jan 2024 23:48:24 +0000 Subject: [PATCH 044/155] fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 9a3b4190c..52cb8e053 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -182,12 +182,6 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple: # giving us tensors of shape 
[num_nodes, global_feat]. global_attributes = tree.tree_map( lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_) - print('SHAPES') - print(nodes.shape, sent_attributes.shape, received_attributes.shape, global_attributes.shape) - print(senders.shape) - print(receivers.shape) - print(sum_n_node) - print(edges.shape) node_fn_inputs = torch.cat( [nodes, sent_attributes, received_attributes, global_attributes], dim=-1) From f6e1cb7989a68502d19a6cb192e907e5d069c901 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 00:07:51 +0000 Subject: [PATCH 045/155] fix --- tests/modeldiffs/ogbg/compare.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index f091d3d4f..1c552899b 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -20,6 +20,7 @@ def key_transform(k): new_key = [] bn = False ln = False + print("Key transform input ", k) for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i @@ -42,6 +43,7 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) From 72573f4eea6a81062dd975b8eb83bb440979d7ce Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 3 Jan 2024 20:34:31 -0500 Subject: [PATCH 046/155] Fix names --- .../workloads/imagenet_vit/imagenet_jax/workload.py | 4 ++-- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 1acd58bcd..4b12247c2 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -95,14 +95,14 @@ def use_glu(self) -> bool: return True -class ImagenetViTPostLNWorkload(ImagenetVitWorkload): +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapLNWorkload(ImagenetVitWorkload): +class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 013bc643f..645b795ca 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -89,14 +89,14 @@ def use_glu(self) -> bool: return True -class ImagenetViTPostLNWorkload(ImagenetVitWorkload): +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapWorkload(ImagenetVitWorkload): +class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: From c02493b6c9eab977c445336f64afa72337f512a7 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:44:01 +0000 Subject: [PATCH 047/155] test fix --- tests/modeldiffs/ogbg/compare.py | 36 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 1c552899b..11badf91c 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -12,31 +12,43 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ 
OgbgWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgWorkload as PytWorkload + OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff +MLP_HIDDEN_DIMS = len(PyTorchWorkload.hidden_dims) def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" print("Key transform input ", k) + graph_network_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = i.split('_')[1] continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_network_index = i.split('_')[1] continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = i.split('_')[1] + if graph_network: + count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif 'LayerNorm' in i: + layer_index = i.split('_')[1] + count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') From 8cf4714adcf1f7b51ebf1ba768ce3297482c160e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:50:19 +0000 Subject: [PATCH 048/155] fix --- tests/modeldiffs/ogbg/compare.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 11badf91c..37d7094d2 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -15,9 +15,10 @@ OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff -MLP_HIDDEN_DIMS = len(PyTorchWorkload.hidden_dims) -def key_transform(k): +hidden_dims = JaxWorkload().hidden_dims + +def key_transform(k, hidden_dims): new_key = [] bn = False ln = False @@ -39,7 +40,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = i.split('_')[1] if graph_network: - count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -47,7 +48,7 @@ def key_transform(k): i = 'edge_embedding' elif 'LayerNorm' in i: layer_index = i.split('_')[1] - count = graph_index * 3 * MLP_HIDDEN_DIMS + seq_index * MLP_HIDDEN_DIMS + layer_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From d984e19f0c94a20a1e81117113b18863e14855c0 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:52:01 +0000 Subject: [PATCH 049/155] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 37d7094d2..5a07b8d91 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -85,7 +85,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = 
JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), From d7543f084ffc21e8a509c8e9e57da2570cf0d116 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:53:14 +0000 Subject: [PATCH 050/155] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 5a07b8d91..f3356028e 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -18,7 +18,7 @@ hidden_dims = JaxWorkload().hidden_dims -def key_transform(k, hidden_dims): +def key_transform(k): new_key = [] bn = False ln = False From d0407027c654bf4ec3daa9360ec3ac8cdfa5f304 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:55:29 +0000 Subject: [PATCH 051/155] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index f3356028e..36500c88b 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -25,7 +25,7 @@ def key_transform(k): graph_network = False "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" print("Key transform input ", k) - graph_network_index = 0 + graph_index = 0 seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i @@ -35,7 +35,7 @@ def key_transform(k): seq_index = i.split('_')[1] continue elif 'GraphNetwork' in i: - graph_network_index = i.split('_')[1] + graph_index = i.split('_')[1] continue elif 'Linear' in i: layer_index = i.split('_')[1] From cd15acf23fe544c047a4cd4c73a0cdf203a363d4 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:56:32 +0000 Subject: [PATCH 052/155] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 36500c88b..b760196d1 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -16,7 +16,7 @@ from tests.modeldiffs.diff import out_diff -hidden_dims = JaxWorkload().hidden_dims +hidden_dims = len(JaxWorkload().hidden_dims) def key_transform(k): new_key = [] From e4f5e0849f6cea1e917a84780046b72e0a0dbe8c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:58:21 +0000 Subject: [PATCH 053/155] fix --- tests/modeldiffs/ogbg/compare.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index b760196d1..3fa4132bf 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -32,13 +32,13 @@ def key_transform(k): ln = ln or 'LayerNorm' in i graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: - seq_index = i.split('_')[1] + seq_index = int(i.split('_')[1]) continue elif 'GraphNetwork' in i: - graph_index = i.split('_')[1] + graph_index = int(i.split('_')[1]) continue elif 'Linear' in i: - layer_index = i.split('_')[1] + layer_index = int(i.split('_')[1]) if graph_network: count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) From 9376504f58ac7afcd8e482580e7231dfe776e3ff Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 02:59:30 +0000 Subject: [PATCH 054/155] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 3fa4132bf..e7adf25de 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -47,7 +47,7 @@ def key_transform(k): elif layer_index == 1: i = 'edge_embedding' elif 'LayerNorm' in i: - layer_index = i.split('_')[1] + layer_index = int(i.split('_')[1]) count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: From e2d88101711c42b785e21acba155c682689bfd99 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:24:10 +0000 Subject: [PATCH 055/155] debugging --- tests/modeldiffs/torch2jax_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index d9264b400..c1d4ad48a 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,6 +77,10 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): + print('pytorch sd') + print(pytorch_sd.keys()) + print('jax sd') + print(jax_sd.key()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 6de5f2f1205651c2b64345e133a30de50e374a3e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:26:37 +0000 Subject: [PATCH 056/155] fix --- tests/modeldiffs/torch2jax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index c1d4ad48a..333cda758 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -78,9 +78,9 @@ def key_transform(self, k_transform_fn): def value_transform(self, v_transform_fn): print('pytorch sd') - print(pytorch_sd.keys()) + print(self.pytorch_sd.keys()) print('jax sd') - print(jax_sd.key()) + print(self.jax_sd.key()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 0bb7a6decba1f3d2df0dc3379dbdf64fda49e147 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:28:17 +0000 Subject: [PATCH 057/155] fix --- tests/modeldiffs/torch2jax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 333cda758..a1d6503dc 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -80,7 +80,7 @@ def value_transform(self, v_transform_fn): print('pytorch sd') print(self.pytorch_sd.keys()) print('jax sd') - print(self.jax_sd.key()) + print(self.flattened_jax_model.key()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 4821cdfa8afc4fd7d991724a06d5460eafd239f3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:31:34 +0000 Subject: [PATCH 058/155] fix --- tests/modeldiffs/torch2jax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index a1d6503dc..560a071d6 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -80,7 +80,7 @@ def value_transform(self, v_transform_fn): print('pytorch sd') print(self.pytorch_sd.keys()) print('jax sd') - print(self.flattened_jax_model.key()) + print(self.flattened_jax_model.keys()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], 
self.flattened_jax_model[k]) for k in self.pytorch_sd From 5f19a5afd0649513dd347d5140b422b4daa943ce Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:49:39 +0000 Subject: [PATCH 059/155] fix --- tests/modeldiffs/ogbg/compare.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index e7adf25de..c9d58f658 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -17,6 +17,7 @@ hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps def key_transform(k): new_key = [] @@ -46,6 +47,9 @@ def key_transform(k): i = 'node_embedding' elif layer_index == 1: i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index @@ -64,20 +68,6 @@ def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - return out From 36b65a34cb6338bc09344db724ca2174d8d22ef1 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:51:30 +0000 Subject: [PATCH 060/155] fix --- tests/modeldiffs/ogbg/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index c9d58f658..95f3e7df4 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -49,7 +49,7 @@ def key_transform(k): i = 'edge_embedding' elif layer_index == 2: count = num_graphs * 3 * hidden_dims - i = + i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index From 50d479e0902ad1be2208a5415316a289d749dd41 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 03:55:45 +0000 Subject: [PATCH 061/155] fix --- tests/modeldiffs/ogbg/compare.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 95f3e7df4..7537362ff 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -66,8 +66,9 @@ def key_transform(k): def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} + for k in sd: + out[k] = sd[k] return out From 5403f5bb76c97f413f7acd05dbad533f4149b1d9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:13:47 +0000 Subject: [PATCH 062/155] fix tests --- tests/modeldiffs/ogbg_gelu/compare.py | 62 +++++++++++---------- tests/modeldiffs/ogbg_model_size/compare.py | 61 +++++++++++--------- tests/modeldiffs/ogbg_silu/compare.py | 62 +++++++++++---------- 3 files changed, 102 insertions(+), 83 deletions(-) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index f6175e99d..f58c58bde 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ 
b/tests/modeldiffs/ogbg_gelu/compare.py @@ -16,53 +16,59 @@ from tests.modeldiffs.diff import out_diff +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" + print("Key transform input ", k) + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +76,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), @@ -110,4 +116,4 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None) \ No newline at end of file diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 3818598ed..4df4d67aa 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -16,53 +16,59 @@ from tests.modeldiffs.diff import out_diff +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" + print("Key transform input ", k) + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' 
in i: + seq_index = int(i.split('_')[1]) continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +76,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), @@ -111,3 +117,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=None) + \ No newline at end of file diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 420ee9020..5fa9eabc9 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -16,53 +16,59 @@ from tests.modeldiffs.diff import out_diff +hidden_dims = len(JaxWorkload().hidden_dims) +num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False + graph_network = False + "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" + print("Key transform input ", k) + graph_index = 0 + seq_index = 0 for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - if 'ModuleList' in i: + graph_network = graph_network or 'GraphNetwork' in i + if 'Sequential' in i: + seq_index = int(i.split('_')[1]) continue - if 'CustomBatchNorm' in i: + elif 'GraphNetwork' in i: + graph_index = int(i.split('_')[1]) continue - if 'Linear' in i: - if 'NonDynamicallyQuantizableLinear' in i: - i = 'out' - else: - i = i.replace('Linear', 'Dense') - elif 'Conv1d' in i: - i = i.replace('Conv1d', 'Conv') - elif 'MHSAwithQS' in i: - i = i.replace('MHSAwithQS', 'SelfAttention') + elif 'Linear' in i: + layer_index = int(i.split('_')[1]) + if graph_network: + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'Dense_' + str(count) + elif layer_index == 0: + i = 'node_embedding' + elif layer_index == 1: + i = 
'edge_embedding' + elif layer_index == 2: + count = num_graphs * 3 * hidden_dims + i = 'Dense_' + str(count) + elif 'LayerNorm' in i: + layer_index = int(i.split('_')[1]) + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: i = i.replace('weight', 'scale') else: i = i.replace('weight', 'kernel') new_key.append(i) + print("New key output", new_key) return tuple(new_key) def sd_transform(sd): # pylint: disable=locally-disabled, modified-iterating-dict, consider-using-dict-items - keys = list(sd.keys()) out = {} - for k in keys: - new_key = k - if len(k) == 5: - _, gn_id, seq_id = k[:3] - gn_id = int(gn_id.split('_')[1]) - seq_id = int(seq_id.split('_')[1]) - if 'LayerNorm' in k[3]: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id}'), k[4]) - else: - new_key = (k[3].replace('0', f'{gn_id*3+seq_id+2}'), k[4]) - elif len(k) == 2 and k[0] == 'Dense_2': - new_key = ('Dense_17', k[1]) - out[new_key] = sd[k] - + for k in sd: + out[k] = sd[k] return out @@ -70,7 +76,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = dict( n_node=torch.LongTensor([5]), @@ -110,4 +116,4 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None) \ No newline at end of file From 20b8cfce440f1fc3dc743b7ed7b43114f9a7907b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:14:51 +0000 Subject: [PATCH 063/155] fix --- tests/modeldiffs/ogbg_gelu/compare.py | 2 +- tests/modeldiffs/ogbg_model_size/compare.py | 3 +-- tests/modeldiffs/ogbg_silu/compare.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index f58c58bde..027e772d5 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgGeluWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgGeluWorkload as PytWorkload + OgbgGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 4df4d67aa..4734a0d0d 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgModelSizeWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgModelSizeWorkload as PytWorkload + OgbgModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -117,4 +117,3 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=None) - \ No newline at end of file diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 5fa9eabc9..52eee4aa8 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ OgbgSiluWorkload as JaxWorkload from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - OgbgSiluWorkload as PytWorkload + OgbgSiluWorkload as 
PyTorchWorkload from tests.modeldiffs.diff import out_diff From 528d44f992e970d8285a3d75a9d910b3b186fc51 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:28:07 +0000 Subject: [PATCH 064/155] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 7537362ff..03f7451dc 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index + 1 i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index + 1 i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 3046a6c12dc09f205a3c58971e6a7d15a131b7d8 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:32:33 +0000 Subject: [PATCH 065/155] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 03f7451dc..d22499636 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + 1 + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + 1 + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 44cb147b3ca11e6591ee2d16d7d296722d72b66a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:34:33 +0000 Subject: [PATCH 066/155] fix --- tests/modeldiffs/ogbg_gelu/compare.py | 4 ++-- tests/modeldiffs/ogbg_model_size/compare.py | 4 ++-- tests/modeldiffs/ogbg_silu/compare.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 027e772d5..b3158d6c4 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 4734a0d0d..b7799411c 100644 --- 
a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 52eee4aa8..675eb4215 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = graph_index * 3 * hidden_dims + seq_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 303bf1ae6b7e096876b552ea5a200e3ed8f42f40 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:46:19 +0000 Subject: [PATCH 067/155] fix --- .../workloads/ogbg/ogbg_pytorch/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 52cb8e053..e6015196a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -14,11 +14,11 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): """Creates a MLP with specified dimensions.""" layers = nn.Sequential() - for dim in hidden_dims: - layers.add_module('dense', nn.Linear(in_features=in_dim, out_features=dim)) - layers.add_module('norm', nn.LayerNorm(dim, eps=1e-6)) - layers.add_module('activation_fn', activation_fn()) - layers.add_module('dropout', nn.Dropout(dropout_rate)) + for i, dim in enumerate(hidden_dims): + layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) + layers.add_module(f'activation_fn_{i}', activation_fn()) + layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) return layers From 8fa2b44b1c78a5880daac52f0460cb2a9356dfa6 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:49:17 +0000 Subject: [PATCH 068/155] fix --- tests/modeldiffs/ogbg/compare.py | 4 ++-- tests/modeldiffs/ogbg_gelu/compare.py | 4 ++-- tests/modeldiffs/ogbg_model_size/compare.py | 4 ++-- tests/modeldiffs/ogbg_silu/compare.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index d22499636..7537362ff 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: 
layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index b3158d6c4..027e772d5 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index b7799411c..4734a0d0d 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 675eb4215..52eee4aa8 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -41,7 +41,7 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -52,7 +52,7 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index + count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From c3650de0610b5294ae901ca7bf8dffd7289a3539 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 04:54:17 +0000 Subject: [PATCH 069/155] fix --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 
e6015196a..31a025732 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -19,6 +19,7 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) + in_dim = dim return layers From 759fc179d35f0166f3ba5caf2e26ee714b46e53b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:41:00 +0000 Subject: [PATCH 070/155] ogbg variant fix --- .../workloads/ogbg/ogbg_pytorch/workload.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index aa0e7ae5e..cd4b3e0a0 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -148,7 +148,8 @@ def init_model_fn( dropout_rate=dropout_rate, hidden_dims=self.hidden_dims, latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps) + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -257,6 +258,7 @@ def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'silu' + class OgbgModelSizeWorkload(OgbgWorkload): @property From 2b758c4a69485f24647cfcdb2e55aace6cb26918 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:45:50 +0000 Subject: [PATCH 071/155] fix ogbg activation fn pytorch models --- algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 31a025732..458ceff48 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -51,9 +51,9 @@ def __init__(self, if activation_fn_name == 'relu': activation_fn = nn.ReLU elif activation_fn_name == 'gelu': - activation_fn = partial(nn.GeLU, approximate='tanh') + activation_fn = partial(nn.GELU, approximate='tanh') elif activation_fn_name == 'silu': - activation_fn = nn.Silu + activation_fn = nn.SiLU else: raise ValueError( f'Invalid activation function name: {self.activation_fn_name}') From 16d6ae46121c5f241f1be08ca77ba7775a1b746c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:56:31 +0000 Subject: [PATCH 072/155] clean up debugging statements --- .../workloads/ogbg/ogbg_jax/workload.py | 8 -------- .../workloads/ogbg/ogbg_pytorch/workload.py | 1 - tests/modeldiffs/ogbg/compare.py | 5 ++--- tests/modeldiffs/ogbg_gelu/compare.py | 5 ++--- tests/modeldiffs/ogbg_model_size/compare.py | 5 ++--- tests/modeldiffs/ogbg_silu/compare.py | 8 ++++---- tests/modeldiffs/torch2jax_utils.py | 4 ---- 7 files changed, 10 insertions(+), 26 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index e4ee57fb7..0ff7f158a 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ 
b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -44,14 +44,6 @@ def init_model_fn( receivers=jnp.asarray([0])) params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch) params = params['params'] - tabulate_fn = nn.tabulate( - self._model, - jax.random.PRNGKey(0), - console_kwargs={ - 'force_terminal': False, 'force_jupyter': False, 'width': 240 - }, - ) - print(tabulate_fn(fake_batch, train=False)) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) return jax_utils.replicate(params), None diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index cd4b3e0a0..513d6a269 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -158,7 +158,6 @@ def init_model_fn( model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) - print(model) return model, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 7537362ff..18980d9f4 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 027e772d5..0c7b1e0d4 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 4734a0d0d..022e05b94 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 
'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 52eee4aa8..feb141057 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -16,16 +16,16 @@ from tests.modeldiffs.diff import out_diff +# Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) num_graphs= JaxWorkload().num_message_passing_steps + def key_transform(k): new_key = [] bn = False ln = False graph_network = False - "Sequential_0', 'GraphNetwork_0', 'Sequential_0', 'Linear_0', 'weight'" - print("Key transform input ", k) graph_index = 0 seq_index = 0 for i in k: @@ -60,7 +60,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print("New key output", new_key) return tuple(new_key) @@ -116,4 +115,5 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) \ No newline at end of file + out_transform=None) + \ No newline at end of file diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 560a071d6..d9264b400 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,10 +77,6 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): - print('pytorch sd') - print(self.pytorch_sd.keys()) - print('jax sd') - print(self.flattened_jax_model.keys()) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 15fd5b113a9676410943fb5f20143577c4d84564 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 16:57:52 +0000 Subject: [PATCH 073/155] formatting --- .../workloads/ogbg/ogbg_jax/workload.py | 2 +- .../workloads/ogbg/ogbg_pytorch/models.py | 24 ++++++++++++------- .../workloads/ogbg/ogbg_pytorch/workload.py | 15 ++++++------ .../workloads/ogbg/workload.py | 2 +- algorithmic_efficiency/workloads/workloads.py | 3 ++- tests/modeldiffs/ogbg/compare.py | 5 ++-- tests/modeldiffs/ogbg_gelu/compare.py | 7 +++--- tests/modeldiffs/ogbg_model_size/compare.py | 5 ++-- tests/modeldiffs/ogbg_silu/compare.py | 6 ++--- 9 files changed, 37 insertions(+), 32 deletions(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 0ff7f158a..7201a2d90 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -139,6 +139,7 @@ def activation_fn_name(self) -> str: """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" return 'silu' + class OgbgModelSizeWorkload(OgbgWorkload): @property @@ -152,4 +153,3 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: return 3 - \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py index 458ceff48..d93013b87 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py @@ -15,7 +15,8 @@ def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): """Creates a MLP with specified dimensions.""" layers = nn.Sequential() for i, dim in enumerate(hidden_dims): - layers.add_module(f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)) + layers.add_module(f'dense_{i}', + nn.Linear(in_features=in_dim, out_features=dim)) layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6)) layers.add_module(f'activation_fn_{i}', activation_fn()) layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate)) @@ -64,7 +65,8 @@ def __init__(self, # specifically update_edge_fn update_node_fn and update_global_fn. if st == 0: in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs - in_dim_node_fn = self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs + in_dim_node_fn = self.latent_dim + self.hidden_dims[ + -1] * 2 + self.num_outputs last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs else: in_dim_edge_fn = self.hidden_dims[-1] * 4 @@ -73,12 +75,18 @@ def __init__(self, graph_network_layers.append( GraphNetwork( - update_edge_fn=_make_mlp( - in_dim_edge_fn, self.hidden_dims, dropout_rate, activation_fn), - update_node_fn=_make_mlp( - in_dim_node_fn, self.hidden_dims, dropout_rate, activation_fn), - update_global_fn=_make_mlp( - last_in_dim, self.hidden_dims, dropout_rate, activation_fn))) + update_edge_fn=_make_mlp(in_dim_edge_fn, + self.hidden_dims, + dropout_rate, + activation_fn), + update_node_fn=_make_mlp(in_dim_node_fn, + self.hidden_dims, + dropout_rate, + activation_fn), + update_global_fn=_make_mlp(last_in_dim, + self.hidden_dims, + dropout_rate, + activation_fn))) self.graph_network = nn.Sequential(*graph_network_layers) self.decoder = nn.Linear( diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 513d6a269..beb518e0f 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -144,12 +144,13 @@ def init_model_fn( """aux_dropout_rate is unused.""" del aux_dropout_rate torch.random.manual_seed(rng[0]) - model = GNN(num_outputs=self._num_outputs, - dropout_rate=dropout_rate, - hidden_dims=self.hidden_dims, - latent_dim=self.latent_dim, - num_message_passing_steps=self.num_message_passing_steps, - activation_fn_name=self.activation_fn_name) + model = GNN( + num_outputs=self._num_outputs, + dropout_rate=dropout_rate, + hidden_dims=self.hidden_dims, + latent_dim=self.latent_dim, + num_message_passing_steps=self.num_message_passing_steps, + activation_fn_name=self.activation_fn_name) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -259,7 +260,7 @@ def activation_fn_name(self) -> str: class OgbgModelSizeWorkload(OgbgWorkload): - + @property def hidden_dims(self) -> Tuple[int]: return (256, 256) diff --git 
a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index ade91b35d..a32f385cb 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -22,7 +22,7 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'mean_average_precision' - @property + @property def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'relu' diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 09ddfabfd..a9cbec1e8 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -103,7 +103,8 @@ 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload' }, 'ogbg_model_size': { - 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgModelSizeWorkload' + 'workload_path': 'ogbg/ogbg', + 'workload_class_name': 'OgbgModelSizeWorkload' }, 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, 'wmt_post_ln': { diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 18980d9f4..53a500085 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -15,10 +15,9 @@ OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 0c7b1e0d4..964032da7 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -15,10 +15,9 @@ OgbgGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue @@ -115,4 +114,4 @@ def sd_transform(sd): pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) \ No newline at end of file + out_transform=None) diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 022e05b94..b90e3d8a8 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -15,10 +15,9 @@ OgbgModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = 
JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index feb141057..10bc79f57 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -15,10 +15,9 @@ OgbgSiluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff - # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) -num_graphs= JaxWorkload().num_message_passing_steps +num_graphs = JaxWorkload().num_message_passing_steps def key_transform(k): @@ -31,7 +30,7 @@ def key_transform(k): for i in k: bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - graph_network = graph_network or 'GraphNetwork' in i + graph_network = graph_network or 'GraphNetwork' in i if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue @@ -116,4 +115,3 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=None) - \ No newline at end of file From a40e0851531d3736c02cba36e69c9cc7d6e2777f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:11:22 +0000 Subject: [PATCH 074/155] remove unused import --- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 7201a2d90..9fc24552d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -3,7 +3,6 @@ from typing import Any, Dict, Optional, Tuple from flax import jax_utils -import flax.linen as nn import jax import jax.numpy as jnp import jraph From ea5bd5ef37aebcd56a499f14b307dadc317b2746 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:18:12 +0000 Subject: [PATCH 075/155] formatting --- tests/modeldiffs/ogbg/compare.py | 9 ++++++--- tests/modeldiffs/ogbg_gelu/compare.py | 9 ++++++--- tests/modeldiffs/ogbg_model_size/compare.py | 7 +++++-- tests/modeldiffs/ogbg_silu/compare.py | 9 ++++++--- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 53a500085..40b92ce4f 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -37,10 +37,12 @@ def key_transform(k): elif 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_gelu/compare.py 
b/tests/modeldiffs/ogbg_gelu/compare.py index 964032da7..4c87366f3 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -37,10 +37,12 @@ def key_transform(k): elif 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index b90e3d8a8..11a74b26a 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -40,7 +40,9 @@ def key_transform(k): elif 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 10bc79f57..bb47cb4be 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -37,10 +37,12 @@ def key_transform(k): elif 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + + layer_index) i = 'Dense_' + str(count) elif layer_index == 0: i = 'node_embedding' @@ -51,7 +53,8 @@ def key_transform(k): i = 'Dense_' + str(count) elif 'LayerNorm' in i: layer_index = int(i.split('_')[1]) - count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index + count = ( + graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index) i = 'LayerNorm_' + str(count) elif 'weight' in i: if bn or ln: From 59156c700294383484352d1d22a65fa1d7b83e8d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:44:17 +0000 Subject: [PATCH 076/155] lint fix --- tests/modeldiffs/ogbg/compare.py | 2 +- tests/modeldiffs/ogbg_gelu/compare.py | 2 +- tests/modeldiffs/ogbg_model_size/compare.py | 2 +- tests/modeldiffs/ogbg_silu/compare.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 40b92ce4f..56316ba12 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in 
i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue if 'Linear' in i: diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 4c87366f3..b58bcd461 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue if 'Linear' in i: diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 11a74b26a..f32d53171 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue elif 'Linear' in i: diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index bb47cb4be..2922b7046 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -34,7 +34,7 @@ def key_transform(k): if 'Sequential' in i: seq_index = int(i.split('_')[1]) continue - elif 'GraphNetwork' in i: + if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue if 'Linear' in i: From 28b8dfb1e5f419c37b62e7281ce2b17060a6b358 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 4 Jan 2024 17:52:53 +0000 Subject: [PATCH 077/155] lint fix --- tests/modeldiffs/ogbg_model_size/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index f32d53171..62443bbb5 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -37,7 +37,7 @@ def key_transform(k): if 'GraphNetwork' in i: graph_index = int(i.split('_')[1]) continue - elif 'Linear' in i: + if 'Linear' in i: layer_index = int(i.split('_')[1]) if graph_network: count = ( From c92075e8c02e105d6a04456badd95ad0da1e10de Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 5 Jan 2024 11:34:29 -0800 Subject: [PATCH 078/155] Update README.md Fix typo for self_tuning commands. 
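The lint fixes above are all the same mechanical change: when the previous branch ends in `continue`, the following `elif` can become a plain `if` without altering control flow. A toy illustration (hypothetical example, not repository code):

```python
def classify(parts):
  kinds = []
  for p in parts:
    if 'Sequential' in p:
      continue
    if 'GraphNetwork' in p:  # previously `elif`; behavior is unchanged
      kinds.append('graph')
      continue
    kinds.append('other')
  return kinds

assert classify(['Sequential_0', 'GraphNetwork_1', 'Linear_0']) == ['graph', 'other']
```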
--- .../prize_qualification_baselines/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md index 8276887da..100555964 100644 --- a/reference_algorithms/prize_qualification_baselines/README.md +++ b/reference_algorithms/prize_qualification_baselines/README.md @@ -50,8 +50,8 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc The prize qualification baseline submissions for jax are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` -- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py` Example command: @@ -62,7 +62,7 @@ python3 submission_runner.py \ --experiment_dir=<experiment_dir> \ --experiment_name=<experiment_name> \ --workload=<workload> \ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ + --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \ --tuning_ruleset=self ``` @@ -70,8 +70,8 @@ python3 submission_runner.py \ The prize qualification baseline submissions for PyTorch are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py` +- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py` Example command: @@ -82,6 +82,6 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir=<experiment_dir> \ --experiment_name=<experiment_name> \ --workload=<workload> \ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \ --tuning_ruleset=self ``` From 1fdb92f5f854c31137865afd95a55088cf2492bf Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Mon, 8 Jan 2024 07:38:03 +0000 Subject: [PATCH 079/155] vit variants comparators --- .../imagenet_vit/imagenet_jax/models.py | 2 +- tests/modeldiffs/imagenet_vit/compare.py | 26 +++++--- tests/modeldiffs/imagenet_vit_glu/__init__.py | 0 .../compare.py} | 0 tests/modeldiffs/imagenet_vit_map/__init__.py | 0 tests/modeldiffs/imagenet_vit_map/compare.py | 63 +++++++++++++++++++ .../imagenet_vit_postln/__init__.py | 0 .../compare.py} | 4 +- 8 files changed, 84 insertions(+), 11 deletions(-) create mode 100644 tests/modeldiffs/imagenet_vit_glu/__init__.py rename tests/modeldiffs/{imagenet_vit/glu_compare.py => imagenet_vit_glu/compare.py} (100%) create mode 100644 tests/modeldiffs/imagenet_vit_map/__init__.py create mode 100644 tests/modeldiffs/imagenet_vit_map/compare.py create mode 100644 tests/modeldiffs/imagenet_vit_postln/__init__.py rename tests/modeldiffs/{imagenet_vit/post_ln_compare.py => imagenet_vit_postln/compare.py} (94%) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 32e748ec7..7c9f40b1b 100644 ---
a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -153,7 +153,7 @@ def __call__(self, x): probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, + num_heads=self.num_heads,use_bias=False, kernel_init=nn.initializers.xavier_uniform())(probe, x) y = nn.LayerNorm()(x) diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index bf7d6dfa5..e0c2506a6 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -25,24 +25,32 @@ def key_transform(k): new_key = [] bn = False attention = False + pool_head = 'MAPHead' in k[0] ln = False enc_block = False for idx, i in enumerate(k): bn = bn or 'BatchNorm' in i ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i + attention = attention or 'SelfAttention' in i or 'MultiheadAttention' in i if 'ModuleList' in i or 'Sequential' in i: continue if 'CustomBatchNorm' in i: continue if 'Linear' in i: if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] + if pool_head: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key_value', + 'Linear_2': 'out', + }[i] + else: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] else: i = i.replace('Linear', 'Dense') elif 'Conv2d' in i: @@ -54,11 +62,13 @@ def key_transform(k): i = 'Transformer' elif enc_block and 'SelfAttention' in i: i = 'MultiHeadDotProductAttention_1' + elif pool_head and 'MultiheadAttention' in i: + i = i.replace('MultiheadAttention', 'MultiHeadDotProductAttention') elif enc_block and i == 'LayerNorm_1': i = 'LayerNorm_2' elif enc_block and 'MlpBlock' in i: i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': + elif idx == 1 and i == 'LayerNorm_0' and not pool_head: i = 'encoder_layernorm' elif 'weight' in i: if bn or ln: diff --git a/tests/modeldiffs/imagenet_vit_glu/__init__.py b/tests/modeldiffs/imagenet_vit_glu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py similarity index 100% rename from tests/modeldiffs/imagenet_vit/glu_compare.py rename to tests/modeldiffs/imagenet_vit_glu/compare.py diff --git a/tests/modeldiffs/imagenet_vit_map/__init__.py b/tests/modeldiffs/imagenet_vit_map/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py new file mode 100644 index 000000000..e7c4c2ee8 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -0,0 +1,63 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. 
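A note on the MAP-head comparator being added here: because the pooling probe supplies the only query, the PyTorch attention keeps its key and value projections fused in a single `key_value` linear, and the state-dict transform in this new file recovers the two Flax matrices by halving that tensor. A minimal illustration of the split (the width is an assumption for the example):

```python
import torch

d = 8  # assumed embedding width for the example
key_value_weight = torch.randn(2 * d, d)  # fused projection: key on top, value below
key_w, value_w = key_value_weight.chunk(2)  # split along dim 0, as in sd_transform
assert key_w.shape == value_w.shape == (d, d)
```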
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitMapWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitMapWorkload as PytWorkload + + +def sd_transform(sd): + out = {} + for k in sd: + if len(k) > 2 and k[-2] == 'key_value': + chunk0, chunk1 = sd[k].chunk(2) + out[(*k[:-2], 'key', k[-1])] = chunk0 + out[(*k[:-2], 'value', k[-1])] = chunk1 + else: + out[k] = sd[k] + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + ) diff --git a/tests/modeldiffs/imagenet_vit_postln/__init__.py b/tests/modeldiffs/imagenet_vit_postln/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py similarity index 94% rename from tests/modeldiffs/imagenet_vit/post_ln_compare.py rename to tests/modeldiffs/imagenet_vit_postln/compare.py index 8bf0bef7e..e73a140f5 100644 --- a/tests/modeldiffs/imagenet_vit/post_ln_compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -11,9 +11,9 @@ from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetViTPostLNWorkload as JaxWorkload + ImagenetVitPostLNWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PytWorkload + ImagenetVitPostLNWorkload as PytWorkload sd_transform = None From cb8cab9d8edf9586254e51d07b8d0029b5860324 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Mon, 8 Jan 2024 07:43:57 +0000 Subject: [PATCH 080/155] style fix --- .../workloads/imagenet_vit/imagenet_jax/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 7c9f40b1b..d701e89ae 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -153,7 +153,8 @@ def __call__(self, x): probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads,use_bias=False, + num_heads=self.num_heads, + use_bias=False, kernel_init=nn.initializers.xavier_uniform())(probe, x) y = nn.LayerNorm()(x) From 45402699b8bc4f9f9673569fbf9f882b7e7cdd97 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Thu, 11 Jan 2024 10:52:53 +0100 Subject: [PATCH 
081/155] Fix typo (size of LibriSpeech) --- datasets/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/README.md b/datasets/README.md index 37480f4f8..c68a5cc6b 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -423,7 +423,7 @@ $DATA_DIR │ │ ├── [...] ``` -In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 338 GB (via `du -sch librispeech/`). +In total, it should contain 543,323 files (via `find -type f | wc -l`) for a total of 388 GB (via `du -sch librispeech/`).
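The same totals can be sanity-checked from Python; a small sketch assuming the dataset lives under `/data/librispeech` (the path and decimal rounding are assumptions, the shell commands quoted above are authoritative):

```python
from pathlib import Path

root = Path('/data/librispeech')  # assumed location of the extracted dataset
files = [p for p in root.rglob('*') if p.is_file()]
total_gb = sum(p.stat().st_size for p in files) / 1e9
print(f'{len(files):,} files, {total_gb:.0f} GB')  # expect 543,323 files, ~388 GB
```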
#### Training SPM Tokenizer From 215dadc57f1ed98523a0bd9e006d722e190d195e Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Fri, 12 Jan 2024 11:17:58 +0100 Subject: [PATCH 082/155] Move registration deadline and highlight form --- CALL_FOR_SUBMISSIONS.md | 5 +++-- COMPETITION_RULES.md | 2 +- README.md | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CALL_FOR_SUBMISSIONS.md b/CALL_FOR_SUBMISSIONS.md index 30207ac7f..84697f577 100644 --- a/CALL_FOR_SUBMISSIONS.md +++ b/CALL_FOR_SUBMISSIONS.md @@ -13,8 +13,9 @@ Submissions can compete under two hyperparameter tuning rulesets (with separate ## Dates -- **Call for submissions: November 28th, 2023** -- Registration deadline to express non-binding intent to submit: January 28th, 2024 +- Call for submissions: November 28th, 2023 +- **Registration deadline to express non-binding intent to submit: February 28th, 2024**.\ +Please fill out the (mandatory but non-binding) [**registration form**](https://forms.gle/K7ty8MaYdi2AxJ4N8). - **Submission deadline: March 28th, 2024** - **Deadline for self-reporting preliminary results: May 28th, 2024** - [tentative] Announcement of all results: July 15th, 2024 diff --git a/COMPETITION_RULES.md b/COMPETITION_RULES.md index 85f16c4cf..92e1af072 100644 --- a/COMPETITION_RULES.md +++ b/COMPETITION_RULES.md @@ -41,7 +41,7 @@ The Competition is open to English-speaking individuals and teams (made of indiv The Competition begins at 12:01am (ET) on November 28, 2023 and ends at 11:59pm (ET) on May 28, 2024, all according to Sponsor's time clock, which decisions are final (the "Competition Period"). There are several deadlines contained within the Competition Period: -- **Intention to Submit.** You must register your Intention to Submit no later than 11:59pm ET on January 28, 2024. +- **Intention to Submit.** You must register your Intention to Submit no later than 11:59pm ET on February 28, 2024. - **Submission Period.** You must complete your Submission and enter it after the Intention to Submit deadline, but no later than 11:59pm ET on March 28, 2024. - **Deadline for self-reporting results.** 11:59pm ET on May 28, 2024. diff --git a/README.md b/README.md index 941344903..06cd7200d 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,8 @@ > [!IMPORTANT] > Upcoming Deadline: -> Registration deadline to express non-binding intent to submit: **January 28th, 2024** +> Registration deadline to express non-binding intent to submit: **February 28th, 2024**.\ +> **If you consider submitting, please fill out the** (mandatory but non-binding) [**registration form**](https://forms.gle/K7ty8MaYdi2AxJ4N8). 
## Table of Contents From 479835becd9be93df1737e043fdeac6faa7c3f7c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Fri, 12 Jan 2024 11:43:36 +0100 Subject: [PATCH 083/155] Move directories and correct the links baselines -> /reference_algorithms/paper_baselines /reference_algorithms/prize_qualification_baselines -> prize_qualification_baselines --- .github/workflows/linting.yml | 2 +- .github/workflows/regression_tests.yml | 32 +++++++++---------- .../workflows/regression_tests_variants.yml | 10 +++--- COMPETITION_RULES.md | 2 +- CONTRIBUTING.md | 4 +-- DOCUMENTATION.md | 6 ++-- GETTING_STARTED.md | 9 +++--- README.md | 8 ++--- baselines/README.md | 3 -- .../README.md | 28 ++++++++-------- .../external_tuning/jax_nadamw_full_budget.py | 0 .../jax_nadamw_target_setting.py | 0 .../pytorch_nadamw_full_budget.py | 0 .../pytorch_nadamw_target_setting.py | 0 .../external_tuning/tuning_search_space.json | 0 .../self_tuning/jax_nadamw_full_budget.py | 0 .../self_tuning/jax_nadamw_target_setting.py | 0 .../self_tuning/pytorch_nadamw_full_budget.py | 0 .../pytorch_nadamw_target_setting.py | 0 .../paper_baselines/README.md | 14 ++++++++ .../paper_baselines}/__init__.py | 0 .../paper_baselines}/adafactor/__init__.py | 0 .../adafactor/jax/__init__.py | 0 .../adafactor/jax/sharded_adafactor.py | 0 .../adafactor/jax/submission.py | 3 +- .../adafactor/pytorch/__init__.py | 0 .../adafactor/pytorch/submission.py | 0 .../adafactor/tuning_search_space.json | 0 .../tuning_search_space_no_beta1.json | 0 .../paper_baselines}/adamw/__init__.py | 0 .../paper_baselines}/adamw/jax/__init__.py | 0 .../paper_baselines}/adamw/jax/submission.py | 0 .../adamw/pytorch/__init__.py | 0 .../adamw/pytorch/submission.py | 0 .../adamw/tuning_search_space.json | 0 .../adamw/tuning_search_space_no_beta1.json | 0 .../paper_baselines}/lamb/__init__.py | 0 .../paper_baselines}/lamb/jax/__init__.py | 0 .../paper_baselines}/lamb/jax/submission.py | 0 .../paper_baselines}/lamb/pytorch/__init__.py | 0 .../lamb/pytorch/submission.py | 0 .../lamb/tuning_search_space.json | 0 .../lamb/tuning_search_space_no_beta1.json | 0 .../paper_baselines}/momentum/__init__.py | 0 .../paper_baselines}/momentum/jax/__init__.py | 0 .../momentum/jax/submission.py | 0 .../momentum/pytorch/__init__.py | 0 .../momentum/pytorch/submission.py | 0 .../momentum/tuning_search_space.json | 0 .../tuning_search_space_no_beta1.json | 0 .../paper_baselines}/nadamw/__init__.py | 0 .../paper_baselines}/nadamw/jax/__init__.py | 0 .../paper_baselines}/nadamw/jax/submission.py | 0 .../nadamw/pytorch/__init__.py | 0 .../nadamw/pytorch/submission.py | 0 .../nadamw/tuning_search_space.json | 0 .../nadamw/tuning_search_space_no_beta1.json | 0 .../paper_baselines}/nesterov/__init__.py | 0 .../paper_baselines}/nesterov/jax/__init__.py | 0 .../nesterov/jax/submission.py | 0 .../nesterov/pytorch/__init__.py | 0 .../nesterov/pytorch/submission.py | 0 .../nesterov/tuning_search_space.json | 0 .../tuning_search_space_no_beta1.json | 0 .../paper_baselines}/sam/__init__.py | 0 .../paper_baselines}/sam/jax/__init__.py | 0 .../paper_baselines}/sam/jax/submission.py | 0 .../paper_baselines}/sam/pytorch/__init__.py | 0 .../sam/pytorch/submission.py | 0 .../sam/tuning_search_space.json | 0 .../sam/tuning_search_space_no_beta1.json | 0 .../paper_baselines}/shampoo/__init__.py | 0 .../paper_baselines}/shampoo/jax/__init__.py | 0 .../shampoo/jax/distributed_shampoo.py | 0 .../shampoo/jax/submission.py | 3 +- .../shampoo/pytorch/__init__.py | 0 .../shampoo/tuning_search_space.json | 0 
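The directory moves in this patch follow two mechanical rules, so any hard-coded submission path can be rewritten with a prefix substitution. A hypothetical helper (not part of the change) that captures both rules:

```python
_MOVES = {
    'baselines/': 'reference_algorithms/paper_baselines/',
    'reference_algorithms/prize_qualification_baselines/':
        'prize_qualification_baselines/',
}

def new_path(old):
  for src, dst in _MOVES.items():
    if old.startswith(src):
      return dst + old[len(src):]
  return old

assert (new_path('baselines/adamw/jax/submission.py') ==
        'reference_algorithms/paper_baselines/adamw/jax/submission.py')
```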
.../shampoo/tuning_search_space_no_beta1.json | 0 tests/test_baselines.py | 7 ++-- 79 files changed, 73 insertions(+), 58 deletions(-) delete mode 100644 baselines/README.md rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/README.md (52%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/external_tuning/jax_nadamw_full_budget.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/external_tuning/jax_nadamw_target_setting.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/external_tuning/pytorch_nadamw_full_budget.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/external_tuning/pytorch_nadamw_target_setting.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/external_tuning/tuning_search_space.json (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/self_tuning/jax_nadamw_full_budget.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/self_tuning/jax_nadamw_target_setting.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/self_tuning/pytorch_nadamw_full_budget.py (100%) rename {reference_algorithms/prize_qualification_baselines => prize_qualification_baselines}/self_tuning/pytorch_nadamw_target_setting.py (100%) create mode 100644 reference_algorithms/paper_baselines/README.md rename {baselines => reference_algorithms/paper_baselines}/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/jax/sharded_adafactor.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/jax/submission.py (98%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/adafactor/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/jax/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/adamw/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/lamb/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/lamb/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/lamb/jax/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/lamb/pytorch/__init__.py (100%) rename {baselines => 
reference_algorithms/paper_baselines}/lamb/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/lamb/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/lamb/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/jax/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/momentum/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/jax/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/nadamw/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/jax/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/nesterov/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/jax/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/pytorch/submission.py (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/tuning_search_space.json (100%) rename {baselines => reference_algorithms/paper_baselines}/sam/tuning_search_space_no_beta1.json (100%) rename {baselines => reference_algorithms/paper_baselines}/shampoo/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/shampoo/jax/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/shampoo/jax/distributed_shampoo.py (100%) rename {baselines => reference_algorithms/paper_baselines}/shampoo/jax/submission.py (98%) rename {baselines => reference_algorithms/paper_baselines}/shampoo/pytorch/__init__.py (100%) rename {baselines => reference_algorithms/paper_baselines}/shampoo/tuning_search_space.json (100%) rename {baselines => 
reference_algorithms/paper_baselines}/shampoo/tuning_search_space_no_beta1.json (100%) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index d89ae8887..89b5ef288 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -18,8 +18,8 @@ jobs: - name: Run pylint run: | pylint algorithmic_efficiency - pylint baselines pylint reference_algorithms + pylint prize_qualification_baselines pylint submission_runner.py pylint tests diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml index 3a0736fa2..cb8595f58 100644 --- a/.github/workflows/regression_tests.yml +++ b/.github/workflows/regression_tests.yml @@ -44,7 +44,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s baselines/adamw/jax/submission.py -w fastmri -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false imagenet_resnet_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -53,7 +53,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s baselines/adamw/jax/submission.py -w imagenet_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false imagenet_vit_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -62,7 +62,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host 
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s baselines/adamw/jax/submission.py -w imagenet_vit -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false ogbg_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -71,7 +71,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s baselines/adamw/jax/submission.py -w ogbg -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -80,7 +80,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false librispeech_conformer_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -89,7 +89,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} 
- docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_conformer -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false librispeech_deepspeech_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -98,7 +98,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_deepspeech -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false wmt_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -107,7 +107,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s baselines/adamw/jax/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false fastmri_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -116,7 
+116,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s baselines/adamw/pytorch/submission.py -w fastmri -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false imagenet_resnet_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -125,7 +125,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s baselines/adamw/pytorch/submission.py -w imagenet_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false imagenet_vit_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -134,7 +134,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s baselines/adamw/pytorch/submission.py -w imagenet_vit -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s 
reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false ogbg_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -143,7 +143,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s baselines/adamw/pytorch/submission.py -w ogbg -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -152,7 +152,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false exit $? 
librispeech_conformer_pytorch: runs-on: self-hosted @@ -162,7 +162,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s baselines/adamw/pytorch/submission.py -w librispeech_conformer -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false librispeech_deepspeech_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -171,7 +171,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false wmt_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -180,4 +180,4 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host 
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 15eccba4c..ef1585d0d 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -44,7 +44,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -53,7 +53,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_layernorm_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -62,7 +62,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host 
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -71,7 +71,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -80,6 +80,6 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_embed_init -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb_embed_init -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false diff --git a/COMPETITION_RULES.md b/COMPETITION_RULES.md index 92e1af072..beca743e0 100644 --- 
a/COMPETITION_RULES.md +++ b/COMPETITION_RULES.md @@ -79,7 +79,7 @@ Submissions must use specific versions of PyTorch and JAX, provided by Sponsor. ## Scoring -All otherwise qualified Submissions shall be scored. Submissions will be scored based on their required training time to reach the target performance on the validation set of each workload, using measuring techniques designed to give all Submissions equal parity. In the event that no Submission in a ruleset receives a score exceeding that of both [prize qualification baselines](./reference_algorithms/prize_qualification_baselines/README.md), no prizes will be awarded for this ruleset. The Teams with the highest scores will be determined to be winners ("Selected Teams"). In the event of a tie the prize money will be split equally between the winners. +All otherwise qualified Submissions shall be scored. Submissions will be scored based on their required training time to reach the target performance on the validation set of each workload, using measuring techniques designed to give all Submissions equal parity. In the event that no Submission in a ruleset receives a score exceeding that of both [prize qualification baselines](./prize_qualification_baselines/README.md), no prizes will be awarded for this ruleset. The Teams with the highest scores will be determined to be winners ("Selected Teams"). In the event of a tie the prize money will be split equally between the winners. ## Submissions diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b22cb5f3a..364bbee62 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -228,7 +228,7 @@ To run the below commands, use the versions installed via `pip install -e '.[dev To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!): ```bash -yapf -i -r -vv -p algorithmic_efficiency baselines datasets reference_algorithms tests *.py +yapf -i -r -vv -p algorithmic_efficiency datasets prize_qualification_baselines reference_algorithms tests *.py ``` To sort all import orderings, run the following: @@ -247,8 +247,8 @@ To print out all offending pylint issues, run the following: ```bash pylint algorithmic_efficiency -pylint baselines pylint datasets +pylint prize_qualification_baselines pylint reference_algorithms pylint submission_runner.py pylint tests diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index de7a3b7f8..a25f5b689 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -38,10 +38,9 @@ - [How can I know if my code can be run on benchmarking hardware?](#how-can-i-know-if-my-code-can-be-run-on-benchmarking-hardware) - [Are we allowed to use our own hardware to self-report the results?](#are-we-allowed-to-use-our-own-hardware-to-self-report-the-results) - [What can I do if running the benchmark is too expensive for me?](#what-can-i-do-if-running-the-benchmark-is-too-expensive-for-me) - - [Can I submit existing (i.e. 
published) training algorithms as submissions?](#can-i-submit-previously-published-training-algorithms-as-submissions) + - [Can I submit previously published training algorithms as submissions?](#can-i-submit-previously-published-training-algorithms-as-submissions) - [Disclaimers](#disclaimers) - [Shared Data Pipelines between JAX and PyTorch](#shared-data-pipelines-between-jax-and-pytorch) - - [Pytorch Conformer CUDA OOM](#pytorch-conformer-cuda-oom) ## Introduction @@ -517,7 +516,7 @@ To ensure that all submitters can develop their submissions based on the same co #### My machine only has one GPU. How can I use this repo? -You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes in our reference algorithms `algorithmic-efficiency/baselines` and `algorithmic-efficiency/reference_algorithms` are tuned for a machine with 8 16GB V100 GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the benchmarking hardware, so if you are using fewer +You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes in our reference algorithms (e.g. `algorithmic-efficiency/prize_qualification_baselines` and `algorithmic-efficiency/reference_algorithms`) are tuned for a machine with 8 16GB V100 GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the benchmarking hardware, so if you are using fewer GPUs with higher per GPU memory, please monitor your memory usage to make sure it will fit on 8xV100 GPUs with 16GB of VRAM per card. #### How do I run this on my SLURM cluster? @@ -576,4 +575,3 @@ The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details. While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. - diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index b13f9f00c..96a7b7d6f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -163,6 +163,7 @@ singularity build --fakeroot .sif Singularity.def ``` Note that this can take several minutes. 
Then, to start a shell session with GPU support (by using the `--nv` flag), we can run + ```bash singularity shell --bind $HOME/data:/data,$HOME/experiment_runs:/experiment_runs \ --nv .sif @@ -194,7 +195,7 @@ Make a submissions subdirectory to store your submission modules e.g. `algorithm ### Coding your Submission -You can find examples of sumbission modules under `algorithmic-efficiency/baselines` and `algorithmic-efficiency/reference_algorithms`. \ +You can find examples of submission modules under `algorithmic-efficiency/prize_qualification_baselines` and `algorithmic-efficiency/reference_algorithms`. \ A submission for the external ruleset will consist of a submission module and a tuning search space definition. 1. Copy the template submission module `submissions/template/submission.py` into your submissions directory e.g. in `algorithmic-efficiency/my_submissions`. @@ -210,7 +211,7 @@ A submission for the external ruleset will consist of a submission module and a } ``` - For a complete example see [tuning_search_space.json](https://github.com/mlcommons/algorithmic-efficiency/blob/main/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json). + For a complete example see [tuning_search_space.json](/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json). 2. Define a range of values for quasirandom sampling by specifying `min`, `max` and `scaling` keys for the hyperparameter: @@ -224,7 +225,7 @@ A submission for the external ruleset will consist of a submission module and a } ``` - For a complete example see [tuning_search_space.json](https://github.com/mlcommons/algorithmic-efficiency/blob/main/baselines/nadamw/tuning_search_space.json). + For a complete example see [tuning_search_space.json](/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json). ## Run your Submission @@ -342,6 +343,6 @@ To produce performance profile and performance table: python3 scoring/score_submission.py --experiment_path= --output_dir= ``` -We provide the scores and performance profiles for the baseline algorithms in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179). +We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179).
**Good Luck!** diff --git a/README.md b/README.md index 06cd7200d..c1c8c961b 100644 --- a/README.md +++ b/README.md @@ -75,8 +75,8 @@ python3 submission_runner.py \ --workload=mnist \ --experiment_dir=$HOME/experiments \ --experiment_name=my_first_experiment \ - --submission_path=baselines/adamw/jax/submission.py \ - --tuning_search_space=baselines/adamw/tuning_search_space.json + --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \ + --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ``` *TL;DR running a PyTorch workload:* @@ -87,8 +87,8 @@ python3 submission_runner.py \ --workload=mnist \ --experiment_dir=$HOME/experiments \ --experiment_name=my_first_experiment \ - --submission_path=baselines/adamw/jax/submission.py \ - --tuning_search_space=baselines/adamw/tuning_search_space.json + --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \ + --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ``` ## Call for Submissions diff --git a/baselines/README.md b/baselines/README.md deleted file mode 100644 index 76f2b9ba0..000000000 --- a/baselines/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Baseline submissions - -Valid baseline submissions for the [external tuning ruleset](../README.md#external-tuning-ruleset). diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/prize_qualification_baselines/README.md similarity index 52% rename from reference_algorithms/prize_qualification_baselines/README.md rename to prize_qualification_baselines/README.md index 100555964..f5bb007be 100644 --- a/reference_algorithms/prize_qualification_baselines/README.md +++ b/prize_qualification_baselines/README.md @@ -8,8 +8,8 @@ This directory contains the baseline(s) that submissions must beat to qualify fo The prize qualification baseline submissions for JAX are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` +- `prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` +- `prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` Example command: @@ -20,16 +20,16 @@ python3 submission_runner.py \ --experiment_dir= \ --experiment_name= \ --workload= \ - --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json + --submission_path=prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_search_space=prize_qualification_baselines/external_tuning/tuning_search_space.json ``` ### PyTorch The prize qualification baseline submissions for PyTorch are: -- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` +- `prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- `prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` Example command: @@ -40,8 +40,8 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir= \ --experiment_name=t \ --workload=\ -
--submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json + --submission_path=prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --tuning_search_space=prize_qualification_baselines/external_tuning/tuning_search_space.json ``` ## Self-tuning Ruleset @@ -50,8 +50,8 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc The prize qualification baseline submissions for JAX are: -- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py` +- `prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py` +- `prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py` Example command: @@ -62,7 +62,7 @@ python3 submission_runner.py \ --experiment_dir= \ --experiment_name= \ --workload= \ - --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \ + --submission_path=prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \ --tuning_ruleset=self ``` @@ -70,8 +70,8 @@ The prize qualification baseline submissions for PyTorch are: -- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py` +- `prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py` +- `prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py` Example command: @@ -82,6 +82,6 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir= \ --experiment_name=t \ --workload=\ - --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \ + --submission_path=prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \ --tuning_ruleset=self ``` diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py rename to prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py rename to prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py rename to prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py rename to prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json similarity index 100% rename from reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json rename to prize_qualification_baselines/external_tuning/tuning_search_space.json diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py rename to prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py rename to prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py rename to prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py rename to prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/paper_baselines/README.md b/reference_algorithms/paper_baselines/README.md new file mode 100644 index 000000000..aadb7eab2 --- /dev/null +++ b/reference_algorithms/paper_baselines/README.md @@ -0,0 +1,14 @@ +# Baseline Submissions from the "Benchmarking Neural Network Training Algorithms" Paper + +This directory contains the baseline submissions for the [external tuning ruleset](../README.md#external-tuning-ruleset) as presented in our paper [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179). They are based on eight different update rules: + +- [Adafactor](/reference_algorithms/paper_baselines/adafactor) +- [AdamW](/reference_algorithms/paper_baselines/adamw) +- [LAMB](/reference_algorithms/paper_baselines/lamb) +- [SGD with Momentum](/reference_algorithms/paper_baselines/momentum) +- [NadamW](/reference_algorithms/paper_baselines/nadamw) +- [SGD with Nesterov Momentum](/reference_algorithms/paper_baselines/nesterov) +- [SAM](/reference_algorithms/paper_baselines/sam) +- [Shampoo](/reference_algorithms/paper_baselines/shampoo/) + +Each update rule has two different tuning search spaces, one where the first momentum parameter (often denoted $\beta_1$) is tuned and one where it is set to a fixed value. 
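To make the two search-space flavors mentioned in the README above concrete, here is a minimal sketch contrasting them. The key names (`learning_rate`, `one_minus_beta1`) and ranges below are illustrative assumptions only, not the values in any baseline's actual `tuning_search_space.json` / `tuning_search_space_no_beta1.json` files:

```python
# Hypothetical sketch of the two tuning-search-space variants per update rule.
# All key names and ranges here are illustrative assumptions.
import json

with_beta1 = {
    "learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"},
    "one_minus_beta1": {"min": 1e-2, "max": 5e-1, "scaling": "log"},
}
# The "no_beta1" variant drops the momentum hyperparameter from the search;
# the submission then fixes beta_1 to a constant value internally.
no_beta1 = {k: v for k, v in with_beta1.items() if k != "one_minus_beta1"}

print(json.dumps(no_beta1, indent=2))
```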
diff --git a/baselines/__init__.py b/reference_algorithms/paper_baselines/__init__.py similarity index 100% rename from baselines/__init__.py rename to reference_algorithms/paper_baselines/__init__.py diff --git a/baselines/adafactor/__init__.py b/reference_algorithms/paper_baselines/adafactor/__init__.py similarity index 100% rename from baselines/adafactor/__init__.py rename to reference_algorithms/paper_baselines/adafactor/__init__.py diff --git a/baselines/adafactor/jax/__init__.py b/reference_algorithms/paper_baselines/adafactor/jax/__init__.py similarity index 100% rename from baselines/adafactor/jax/__init__.py rename to reference_algorithms/paper_baselines/adafactor/jax/__init__.py diff --git a/baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py similarity index 100% rename from baselines/adafactor/jax/sharded_adafactor.py rename to reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py diff --git a/baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py similarity index 98% rename from baselines/adafactor/jax/submission.py rename to reference_algorithms/paper_baselines/adafactor/jax/submission.py index ec8020e7e..2dd85c29b 100644 --- a/baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -10,7 +10,8 @@ import optax from algorithmic_efficiency import spec -from baselines.adafactor.jax.sharded_adafactor import sharded_adafactor +from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \ + sharded_adafactor _GRAD_CLIP_EPS = 1e-6 diff --git a/baselines/adafactor/pytorch/__init__.py b/reference_algorithms/paper_baselines/adafactor/pytorch/__init__.py similarity index 100% rename from baselines/adafactor/pytorch/__init__.py rename to reference_algorithms/paper_baselines/adafactor/pytorch/__init__.py diff --git a/baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py similarity index 100% rename from baselines/adafactor/pytorch/submission.py rename to reference_algorithms/paper_baselines/adafactor/pytorch/submission.py diff --git a/baselines/adafactor/tuning_search_space.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json similarity index 100% rename from baselines/adafactor/tuning_search_space.json rename to reference_algorithms/paper_baselines/adafactor/tuning_search_space.json diff --git a/baselines/adafactor/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/adafactor/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json diff --git a/baselines/adamw/__init__.py b/reference_algorithms/paper_baselines/adamw/__init__.py similarity index 100% rename from baselines/adamw/__init__.py rename to reference_algorithms/paper_baselines/adamw/__init__.py diff --git a/baselines/adamw/jax/__init__.py b/reference_algorithms/paper_baselines/adamw/jax/__init__.py similarity index 100% rename from baselines/adamw/jax/__init__.py rename to reference_algorithms/paper_baselines/adamw/jax/__init__.py diff --git a/baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py similarity index 100% rename from baselines/adamw/jax/submission.py rename to 
reference_algorithms/paper_baselines/adamw/jax/submission.py diff --git a/baselines/adamw/pytorch/__init__.py b/reference_algorithms/paper_baselines/adamw/pytorch/__init__.py similarity index 100% rename from baselines/adamw/pytorch/__init__.py rename to reference_algorithms/paper_baselines/adamw/pytorch/__init__.py diff --git a/baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py similarity index 100% rename from baselines/adamw/pytorch/submission.py rename to reference_algorithms/paper_baselines/adamw/pytorch/submission.py diff --git a/baselines/adamw/tuning_search_space.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json similarity index 100% rename from baselines/adamw/tuning_search_space.json rename to reference_algorithms/paper_baselines/adamw/tuning_search_space.json diff --git a/baselines/adamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/adamw/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json diff --git a/baselines/lamb/__init__.py b/reference_algorithms/paper_baselines/lamb/__init__.py similarity index 100% rename from baselines/lamb/__init__.py rename to reference_algorithms/paper_baselines/lamb/__init__.py diff --git a/baselines/lamb/jax/__init__.py b/reference_algorithms/paper_baselines/lamb/jax/__init__.py similarity index 100% rename from baselines/lamb/jax/__init__.py rename to reference_algorithms/paper_baselines/lamb/jax/__init__.py diff --git a/baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py similarity index 100% rename from baselines/lamb/jax/submission.py rename to reference_algorithms/paper_baselines/lamb/jax/submission.py diff --git a/baselines/lamb/pytorch/__init__.py b/reference_algorithms/paper_baselines/lamb/pytorch/__init__.py similarity index 100% rename from baselines/lamb/pytorch/__init__.py rename to reference_algorithms/paper_baselines/lamb/pytorch/__init__.py diff --git a/baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py similarity index 100% rename from baselines/lamb/pytorch/submission.py rename to reference_algorithms/paper_baselines/lamb/pytorch/submission.py diff --git a/baselines/lamb/tuning_search_space.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json similarity index 100% rename from baselines/lamb/tuning_search_space.json rename to reference_algorithms/paper_baselines/lamb/tuning_search_space.json diff --git a/baselines/lamb/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/lamb/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json diff --git a/baselines/momentum/__init__.py b/reference_algorithms/paper_baselines/momentum/__init__.py similarity index 100% rename from baselines/momentum/__init__.py rename to reference_algorithms/paper_baselines/momentum/__init__.py diff --git a/baselines/momentum/jax/__init__.py b/reference_algorithms/paper_baselines/momentum/jax/__init__.py similarity index 100% rename from baselines/momentum/jax/__init__.py rename to reference_algorithms/paper_baselines/momentum/jax/__init__.py diff --git a/baselines/momentum/jax/submission.py 
b/reference_algorithms/paper_baselines/momentum/jax/submission.py similarity index 100% rename from baselines/momentum/jax/submission.py rename to reference_algorithms/paper_baselines/momentum/jax/submission.py diff --git a/baselines/momentum/pytorch/__init__.py b/reference_algorithms/paper_baselines/momentum/pytorch/__init__.py similarity index 100% rename from baselines/momentum/pytorch/__init__.py rename to reference_algorithms/paper_baselines/momentum/pytorch/__init__.py diff --git a/baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py similarity index 100% rename from baselines/momentum/pytorch/submission.py rename to reference_algorithms/paper_baselines/momentum/pytorch/submission.py diff --git a/baselines/momentum/tuning_search_space.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json similarity index 100% rename from baselines/momentum/tuning_search_space.json rename to reference_algorithms/paper_baselines/momentum/tuning_search_space.json diff --git a/baselines/momentum/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/momentum/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json diff --git a/baselines/nadamw/__init__.py b/reference_algorithms/paper_baselines/nadamw/__init__.py similarity index 100% rename from baselines/nadamw/__init__.py rename to reference_algorithms/paper_baselines/nadamw/__init__.py diff --git a/baselines/nadamw/jax/__init__.py b/reference_algorithms/paper_baselines/nadamw/jax/__init__.py similarity index 100% rename from baselines/nadamw/jax/__init__.py rename to reference_algorithms/paper_baselines/nadamw/jax/__init__.py diff --git a/baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py similarity index 100% rename from baselines/nadamw/jax/submission.py rename to reference_algorithms/paper_baselines/nadamw/jax/submission.py diff --git a/baselines/nadamw/pytorch/__init__.py b/reference_algorithms/paper_baselines/nadamw/pytorch/__init__.py similarity index 100% rename from baselines/nadamw/pytorch/__init__.py rename to reference_algorithms/paper_baselines/nadamw/pytorch/__init__.py diff --git a/baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py similarity index 100% rename from baselines/nadamw/pytorch/submission.py rename to reference_algorithms/paper_baselines/nadamw/pytorch/submission.py diff --git a/baselines/nadamw/tuning_search_space.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json similarity index 100% rename from baselines/nadamw/tuning_search_space.json rename to reference_algorithms/paper_baselines/nadamw/tuning_search_space.json diff --git a/baselines/nadamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/nadamw/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json diff --git a/baselines/nesterov/__init__.py b/reference_algorithms/paper_baselines/nesterov/__init__.py similarity index 100% rename from baselines/nesterov/__init__.py rename to reference_algorithms/paper_baselines/nesterov/__init__.py diff --git a/baselines/nesterov/jax/__init__.py 
b/reference_algorithms/paper_baselines/nesterov/jax/__init__.py similarity index 100% rename from baselines/nesterov/jax/__init__.py rename to reference_algorithms/paper_baselines/nesterov/jax/__init__.py diff --git a/baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py similarity index 100% rename from baselines/nesterov/jax/submission.py rename to reference_algorithms/paper_baselines/nesterov/jax/submission.py diff --git a/baselines/nesterov/pytorch/__init__.py b/reference_algorithms/paper_baselines/nesterov/pytorch/__init__.py similarity index 100% rename from baselines/nesterov/pytorch/__init__.py rename to reference_algorithms/paper_baselines/nesterov/pytorch/__init__.py diff --git a/baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py similarity index 100% rename from baselines/nesterov/pytorch/submission.py rename to reference_algorithms/paper_baselines/nesterov/pytorch/submission.py diff --git a/baselines/nesterov/tuning_search_space.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json similarity index 100% rename from baselines/nesterov/tuning_search_space.json rename to reference_algorithms/paper_baselines/nesterov/tuning_search_space.json diff --git a/baselines/nesterov/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/nesterov/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json diff --git a/baselines/sam/__init__.py b/reference_algorithms/paper_baselines/sam/__init__.py similarity index 100% rename from baselines/sam/__init__.py rename to reference_algorithms/paper_baselines/sam/__init__.py diff --git a/baselines/sam/jax/__init__.py b/reference_algorithms/paper_baselines/sam/jax/__init__.py similarity index 100% rename from baselines/sam/jax/__init__.py rename to reference_algorithms/paper_baselines/sam/jax/__init__.py diff --git a/baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py similarity index 100% rename from baselines/sam/jax/submission.py rename to reference_algorithms/paper_baselines/sam/jax/submission.py diff --git a/baselines/sam/pytorch/__init__.py b/reference_algorithms/paper_baselines/sam/pytorch/__init__.py similarity index 100% rename from baselines/sam/pytorch/__init__.py rename to reference_algorithms/paper_baselines/sam/pytorch/__init__.py diff --git a/baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py similarity index 100% rename from baselines/sam/pytorch/submission.py rename to reference_algorithms/paper_baselines/sam/pytorch/submission.py diff --git a/baselines/sam/tuning_search_space.json b/reference_algorithms/paper_baselines/sam/tuning_search_space.json similarity index 100% rename from baselines/sam/tuning_search_space.json rename to reference_algorithms/paper_baselines/sam/tuning_search_space.json diff --git a/baselines/sam/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/sam/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json diff --git a/baselines/shampoo/__init__.py b/reference_algorithms/paper_baselines/shampoo/__init__.py similarity index 100% rename from 
baselines/shampoo/__init__.py rename to reference_algorithms/paper_baselines/shampoo/__init__.py diff --git a/baselines/shampoo/jax/__init__.py b/reference_algorithms/paper_baselines/shampoo/jax/__init__.py similarity index 100% rename from baselines/shampoo/jax/__init__.py rename to reference_algorithms/paper_baselines/shampoo/jax/__init__.py diff --git a/baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py similarity index 100% rename from baselines/shampoo/jax/distributed_shampoo.py rename to reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py diff --git a/baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py similarity index 98% rename from baselines/shampoo/jax/submission.py rename to reference_algorithms/paper_baselines/shampoo/jax/submission.py index cb062faf3..9c6b66b7f 100644 --- a/baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -10,7 +10,8 @@ import optax from algorithmic_efficiency import spec -from baselines.shampoo.jax.distributed_shampoo import distributed_shampoo +from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \ + distributed_shampoo _GRAD_CLIP_EPS = 1e-6 diff --git a/baselines/shampoo/pytorch/__init__.py b/reference_algorithms/paper_baselines/shampoo/pytorch/__init__.py similarity index 100% rename from baselines/shampoo/pytorch/__init__.py rename to reference_algorithms/paper_baselines/shampoo/pytorch/__init__.py diff --git a/baselines/shampoo/tuning_search_space.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json similarity index 100% rename from baselines/shampoo/tuning_search_space.json rename to reference_algorithms/paper_baselines/shampoo/tuning_search_space.json diff --git a/baselines/shampoo/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json similarity index 100% rename from baselines/shampoo/tuning_search_space_no_beta1.json rename to reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json diff --git a/tests/test_baselines.py b/tests/test_baselines.py index 0a26aa69d..f79e629e7 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -47,6 +47,8 @@ 'jax', ] +baseline_path = "reference_algorithms/paper_baselines" + named_parameters = [] for f in frameworks: for b in baselines[f]: @@ -55,8 +57,9 @@ testcase_name=f'{b}_{f}', workload='mnist', framework=f'{f}', - submission_path=f'baselines/{b}/{f}/submission.py', - tuning_search_space=f'baselines/{b}/tuning_search_space.json')) + submission_path=f'{baseline_path}/{b}/{f}/submission.py', + tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json') + ) class BaselineTest(parameterized.TestCase): From 22ab1a7c82dd973fe55f2e49cb367967109050eb Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Fri, 12 Jan 2024 11:53:37 +0100 Subject: [PATCH 084/155] Add discord badge and callout to contact us --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index c1c8c961b..65bae4d54 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ [![Lint](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml/badge.svg)](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml) [![License: Apache 
2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/mlcommons/algorithmic-efficiency/blob/main/LICENSE.md) [![Code style: yapf](https://img.shields.io/badge/code%20style-yapf-orange)](https://github.com/google/yapf) +[![Discord](https://dcbadge.vercel.app/api/server/5FPXK7SMt6?style=flat)](https://discord.gg/5FPXK7SMt6) --- @@ -43,6 +44,9 @@ ## Installation +> [!TIP] +> **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/). + You can install this package and dependencies in a [Python virtual environment](/GETTING_STARTED.md#python-virtual-environment) or use a [Docker/Singularity/Apptainer container](/GETTING_STARTED.md#docker) (recommended). We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments. Both options are described in detail in the [**Getting Started**](/GETTING_STARTED.md) document. From a8187b80cb6389a7efe629c2f0d82cdc6f540072 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 14:46:52 +0000 Subject: [PATCH 085/155] Add pass/fail thresholds to traindiffs test --- tests/test_traindiffs.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index a1b64a573..d7ef3d3ae 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -10,6 +10,7 @@ from absl import flags from absl.testing import absltest +from numpy import allclose FLAGS = flags.FLAGS @@ -81,6 +82,17 @@ def test_workload(self): print(header) print('=' * len(header)) for i in range(NUM_TRAIN_STEPS): + rtol = 1e-1 if workload == 'librispeech_deepspeech' else 5e-3 + self.assertTrue(allclose(jax_results['eval_results'][i][k], + pyt_results['eval_results'][i][k], + rtol=rtol)) + self.assertTrue(allclose(jax_results['scalars'][i]['grad_norm'], + pyt_results['scalars'][i]['grad_norm'], + rtol=rtol)) + self.assertTrue(allclose(jax_results['scalars'][i]['loss'], + pyt_results['scalars'][i]['loss'], + rtol=rtol)) + row = map(lambda x: str(round(x, 5)), [ jax_results['eval_results'][i][k], From 2373e1599b9052f7d4cc3522d88c5c3a6aa5ceb2 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 14:47:20 +0000 Subject: [PATCH 086/155] Add traindiffs_test option to docker startup script --- docker/scripts/startup.sh | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 30cb6b36b..0d9ad1b1e 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -14,8 +14,8 @@ function usage() { $0 [--dataset dataset] [--framework framework] [--submission_path submission_path] [--tuning_search_space tuning_search_space] [--experiment_name experiment_name] [--workload workload] [--max_global_steps max_global_steps] [--rsync_data rsync_data] - [--internal_contributor true] - + [--internal_contributor true] [--traindiffs_test false] + Options: -d | --dataset: Can be imagenet, criteo1tb, ogbg, fastmri, wmt, librispeech. -f | --framework: Can be jax or pytorch. @@ -34,11 +34,13 @@ function usage() { from internal GCP bucket. 
-i | --internal_contributor: If true, allow rsync of data and transfer of experiment results with GCP project. + --traindiffs_test: If true, ignore all other options and run the traindiffs test. USAGE exit 1 } # Defaults +TEST="false" INTERNAL_CONTRIBUTOR_MODE="false" HOME_DIR="" RSYNC_DATA="true" @@ -47,7 +49,11 @@ SAVE_CHECKPOINTS="true" # Pass flag while [ "$1" != "" ]; do - case $1 in + case $1 in + --traindiffs_test) + shift + TEST=$1 + ;; -d | --dataset) shift DATASET=$1 @@ -106,8 +112,15 @@ while [ "$1" != "" ]; do ;; esac shift -done - +done + +if [[ ${TEST} == "true" ]]; then + cd algorithmic-efficiency + COMMAND="python3 tests/test_traindiffs.py" + echo $COMMAND + eval $COMMAND + exit +fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ From d1da9c7651bc4edf311b3b28ba7c3f232ad3ef5c Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 16:01:57 +0100 Subject: [PATCH 087/155] Rename PytWorkload to PyTorchWorkload --- tests/modeldiffs/criteo1tb/compare.py | 4 ++-- .../criteo1tb_embed_init/compare.py | 4 ++-- .../modeldiffs/criteo1tb_layernorm/compare.py | 4 ++-- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++-- tests/modeldiffs/fastmri/compare.py | 4 ++-- tests/modeldiffs/fastmri_layernorm/compare.py | 4 ++-- .../modeldiffs/fastmri_model_size/compare.py | 4 ++-- tests/modeldiffs/fastmri_tanh/compare.py | 4 ++-- tests/modeldiffs/imagenet_resnet/compare.py | 4 ++-- .../imagenet_resnet/gelu_compare.py | 4 ++-- .../imagenet_resnet/silu_compare.py | 4 ++-- tests/modeldiffs/imagenet_vit/compare.py | 4 ++-- tests/modeldiffs/imagenet_vit/glu_compare.py | 4 ++-- .../imagenet_vit/post_ln_compare.py | 4 ++-- .../librispeech_conformer/compare.py | 4 ++-- .../compare.py | 4 ++-- .../librispeech_conformer_gelu/compare.py | 4 ++-- .../compare.py | 4 ++-- .../librispeech_deepspeech/compare.py | 4 ++-- tests/modeldiffs/wmt/compare.py | 4 ++-- .../modeldiffs/wmt_attention_temp/compare.py | 4 ++-- tests/modeldiffs/wmt_glu_tanh/compare.py | 4 ++-- tests/modeldiffs/wmt_post_ln/compare.py | 4 ++-- tests/reference_algorithm_tests.py | 2 +- tests/test_traindiffs.py | 24 ++++++++++++------- 25 files changed, 62 insertions(+), 56 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 9a95f3656..adbade983 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallWorkload as PytWorkload + Criteo1TbDlrmSmallWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -51,7 +51,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 719484037..0748e2d71 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - 
Criteo1TbDlrmSmallEmbedInitWorkload as PytWorkload + Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -50,7 +50,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 3fc2a750a..0a6e5c5ac 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as PytWorkload + Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index b9dbbc80e..288442594 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallResNetWorkload as PytWorkload + Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 6780ff91e..56b74b32d 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRIWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PytWorkload + FastMRIWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. 
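  # Shape (batch_size, height, width) = (2, 320, 320): two random slices at
  # the 320x320 resolution the FastMRI U-Net expects.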
image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 4be086da3..23ccf26d7 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRILayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRILayerNormWorkload as PytWorkload + FastMRILayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index 60d846b6f..b61516c29 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRIModelSizeWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIModelSizeWorkload as PytWorkload + FastMRIModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 47bad372a..0f455387c 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRITanhWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRITanhWorkload as PytWorkload + FastMRITanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 2fc721ab0..fb730f1bf 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload as PytWorkload + ImagenetResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -72,7 +72,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. 
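  # Shape (batch_size, channels, height, width) = (2, 3, 224, 224): two random
  # RGB images at the standard 224x224 ImageNet resolution, in NCHW layout.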
image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 8c3899076..6c8adbec2 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetGELUWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetGELUWorkload as PytWorkload + ImagenetResNetGELUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -19,7 +19,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index ee74e7bc9..7668cdbd9 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetSiLUWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetSiLUWorkload as PytWorkload + ImagenetResNetSiLUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -19,7 +19,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index bf7d6dfa5..ebf39e4c3 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitWorkload as PytWorkload + ImagenetVitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -75,7 +75,7 @@ def key_transform(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. 
image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit/glu_compare.py index 444f1230a..2c0aa546d 100644 --- a/tests/modeldiffs/imagenet_vit/glu_compare.py +++ b/tests/modeldiffs/imagenet_vit/glu_compare.py @@ -13,7 +13,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitGluWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PytWorkload + ImagenetVitGluWorkload as PyTorchWorkload sd_transform = None @@ -21,7 +21,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit/post_ln_compare.py index 8bf0bef7e..0883b5676 100644 --- a/tests/modeldiffs/imagenet_vit/post_ln_compare.py +++ b/tests/modeldiffs/imagenet_vit/post_ln_compare.py @@ -13,7 +13,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetViTPostLNWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PytWorkload + ImagenetViTPostLNWorkload as PyTorchWorkload sd_transform = None @@ -21,7 +21,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index d414001dd..cfe6c7381 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload as PytWorkload + LibriSpeechConformerWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. 
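  # Shape (batch_size, num_samples) = (2, 320000): two random raw waveforms,
  # i.e. 20 seconds of 16 kHz LibriSpeech-style audio each.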
wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index 64612fbf0..8480fca02 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as PytWorkload + LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index 892040b57..caa9b09b9 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerGeluWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerGeluWorkload as PytWorkload + LibriSpeechConformerGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 784fceb60..1a94d3c77 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerLayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerLayerNormWorkload as PytWorkload + LibriSpeechConformerLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. 
wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 12b79a517..edcc3ba87 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechWorkload as PytWorkload + LibriSpeechDeepSpeechWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -83,7 +83,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 8d0ee8411..41fc5ee17 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkload as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkload as PytWorkload + WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index b50abd3ca..92ce4eb44 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadAttentionTemp as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadAttentionTemp as PytWorkload + WmtWorkloadAttentionTemp as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 1322ad0a0..b8d860479 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadGLUTanH as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadGLUTanH as PytWorkload + WmtWorkloadGLUTanH as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. 
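  # Shape (batch_size, seq_len) = (2, 256): random token ids drawn from a
  # 32000-entry vocabulary.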
inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index bfd701736..3f5469d8d 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadPostLN as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadPostLN as PytWorkload + WmtWorkloadPostLN as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 5c43b233b..74c06e180 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -197,7 +197,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): f'tests.modeldiffs.{workload_name}.compare') jax_params, model_state, _ = diff_utils.torch2jax( jax_workload=super(), - pytorch_workload=compare_module.PytWorkload(**self.init_kwargs), + pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs), key_transform=compare_module.key_transform, sd_transform=compare_module.sd_transform) return (FrozenDict(**jax_utils.replicate(jax_params)), diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index d7ef3d3ae..663cf3de4 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -83,15 +83,21 @@ def test_workload(self): print('=' * len(header)) for i in range(NUM_TRAIN_STEPS): rtol = 1e-1 if workload == 'librispeech_deepspeech' else 5e-3 - self.assertTrue(allclose(jax_results['eval_results'][i][k], - pyt_results['eval_results'][i][k], - rtol=rtol)) - self.assertTrue(allclose(jax_results['scalars'][i]['grad_norm'], - pyt_results['scalars'][i]['grad_norm'], - rtol=rtol)) - self.assertTrue(allclose(jax_results['scalars'][i]['loss'], - pyt_results['scalars'][i]['loss'], - rtol=rtol)) + self.assertTrue( + allclose( + jax_results['eval_results'][i][k], + pyt_results['eval_results'][i][k], + rtol=rtol)) + self.assertTrue( + allclose( + jax_results['scalars'][i]['grad_norm'], + pyt_results['scalars'][i]['grad_norm'], + rtol=rtol)) + self.assertTrue( + allclose( + jax_results['scalars'][i]['loss'], + pyt_results['scalars'][i]['loss'], + rtol=rtol)) row = map(lambda x: str(round(x, 5)), [ From 6a5d63a7868f622215cb0a68205beb89fad62bd9 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 16:06:11 +0100 Subject: [PATCH 088/155] Add traindiffs tests to workflows (self-hosted) --- .github/workflows/traindiffs_tests.yml | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/traindiffs_tests.yml diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml new file mode 100644 index 000000000..5bb84c867 --- /dev/null +++ b/.github/workflows/traindiffs_tests.yml @@ -0,0 +1,32 @@ +name: Containerized training differences tests between Jax and PyTorch + +on: + pull_request: + branches: + - 'main' + +jobs: + build_and_push_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker image + run: | + GIT_BRANCH=${{ github.head_ref || 
github.ref_name }} + FRAMEWORK=both + IMAGE_NAME="algoperf_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + traindiffs_tests: + runs-on: self-hosted + needs: build_and_push_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized traindiffs test + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} algoperf_${{ github.head_ref || github.ref_name }} --traindiffs_test true From 1363e15df2dda680e71a927bc6cf94fb79c925ae Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Tue, 16 Jan 2024 14:33:09 -0400 Subject: [PATCH 089/155] Update models.py --- .../workloads/imagenet_vit/imagenet_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 469716d59..02d708da8 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -256,7 +256,7 @@ def __init__(self, nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=False) + self.width, num_heads=self.num_heads, self_attn=False, bias=True) self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) From 245856760e6bc0d0c0114b96c27e335b5c891dc9 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Tue, 16 Jan 2024 14:33:30 -0400 Subject: [PATCH 090/155] Update models.py --- .../workloads/imagenet_vit/imagenet_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index d701e89ae..639800b44 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -154,7 +154,7 @@ def __call__(self, x): x = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, - use_bias=False, + use_bias=True, kernel_init=nn.initializers.xavier_uniform())(probe, x) y = nn.LayerNorm()(x) From f8b65111b284bdef065d8e829bb03187676ecc5f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 17 Jan 2024 19:54:18 +0000 Subject: [PATCH 091/155] fix ssim calculation --- .../workloads/fastmri/fastmri_jax/ssim.py | 6 ++++++ .../workloads/fastmri/fastmri_pytorch/ssim.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py index 0ece3ffa9..3eec125c1 100644 --- 
a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py @@ -40,6 +40,12 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): logits = logits * std + mean targets = targets * std + mean ssims = jax.vmap(structural_similarity)(logits, targets, volume_max) + + # NOTE(kasimbeg): map out-of-bounds ssims to 1 and -1, the theoretical + # maximum and minimum values of SSIM. + ssims = jnp.where(ssims > 1, jnp.ones_like(ssims), ssims) + ssims = jnp.where(ssims < -1, jnp.ones_like(ssims) * -1, ssims) + return ssims diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py index ebee661c8..3e2f9221e 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py @@ -47,6 +47,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): logits = logits * std + mean targets = targets * std + mean ssims = torch.vmap(structural_similarity)(logits, targets, volume_max) + + ssims = torch.where(ssims > 1, torch.ones_like(ssims), ssims) + ssims = torch.where(ssims < -1, torch.ones_like(ssims) * -1, ssims) return ssims From 4d7a912d0311887864fd4d8313f1eb9ecf59e37d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 17 Jan 2024 19:56:08 +0000 Subject: [PATCH 092/155] add comment --- algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py | 4 ++-- .../workloads/fastmri/fastmri_pytorch/ssim.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py index 3eec125c1..e15b93616 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py @@ -41,11 +41,11 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): targets = targets * std + mean ssims = jax.vmap(structural_similarity)(logits, targets, volume_max) - # NOTE(kasimbeg): map out-of-bounds ssims to 1 and -1, the theoretical + # map out-of-bounds ssims to 1 and -1, the theoretical # maximum and minimum values of SSIM. ssims = jnp.where(ssims > 1, jnp.ones_like(ssims), ssims) ssims = jnp.where(ssims < -1, jnp.ones_like(ssims) * -1, ssims) - + return ssims diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py index 3e2f9221e..eff6fb62f 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py @@ -48,8 +48,11 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None): targets = targets * std + mean ssims = torch.vmap(structural_similarity)(logits, targets, volume_max) + # map out-of-bounds ssims to 1 and -1, the theoretical + # maximum and minimum values of SSIM. 
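+  # Clamp to the theoretical SSIM range [-1, 1], mirroring the Jax fix above:
+  # e.g. a value like 1.0000001 produced by floating-point error becomes 1.0.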
ssims = torch.where(ssims > 1, torch.ones_like(ssims), ssims) ssims = torch.where(ssims < -1, torch.ones_like(ssims) * -1, ssims) + return ssims From 1683ba37ea41faa69ad83a90d3cb044daad75004 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Jan 2024 18:41:12 +0000 Subject: [PATCH 093/155] add variant scoring conditions --- scoring/performance_profile.py | 85 ++++++++++++++++++++++++++++------ scoring/score_submission.py | 11 ++++- 2 files changed, 81 insertions(+), 15 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 84788c7ae..9322dfaa7 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -36,16 +36,19 @@ import pandas as pd import algorithmic_efficiency.workloads.workloads as workloads_registry +from algorithmic_efficiency.workloads.workloads import get_base_workload_name from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS +BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' # These global variables have to be set according to the current set of # workloads and rules for the scoring to be correct. # We do not use the workload registry since it contains test and development # workloads as well. -NUM_WORKLOADS = 8 +NUM_BASE_WORKLOADS = 8 +NUM_VARIANT_WORKLOADS = 6 NUM_TRIALS = 5 MIN_EVAL_METRICS = [ @@ -152,16 +155,17 @@ def get_index_that_reaches_target(workload_df, def get_times_for_submission(submission, - submission_tag, + submission_name, time_col='global_step', verbosity=1, - self_tuning_ruleset=False): + self_tuning_ruleset=False, + strict=False): """Get times to target for each workload in a submission. Args: submission: A DataFrame containing one row for each trial in each workload for a given submission. - submission_tag: Globally unique identified for a submission. + submission_name: Globally unique identified for a submission. time_col: A string indicating which column to use for time. verbosity: Debug level of information; choice of (1, 2, 3). @@ -169,16 +173,23 @@ def get_times_for_submission(submission, DataFrame with columns `submission`, `workload`, and time_col. 
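      The frame is pivoted so that each row is a submission and each column is
      a workload, holding that submission's time to target on that workload.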
""" workloads = [] - submission_name = submission_tag.split('.')[1] num_workloads = len(submission.groupby('workload')) - if num_workloads != NUM_WORKLOADS: - logging.warning(f'Expecting {NUM_WORKLOADS} workloads ' - f'but found {num_workloads} workloads.') + if num_workloads != NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS: + if strict: + raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + logging.warning( + f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' + f'but found {num_workloads} workloads.') for workload, group in submission.groupby('workload'): num_trials = len(group) if num_trials != NUM_TRIALS and not self_tuning_ruleset: - logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') + if strict: + raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + else: + logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) trial_idx, time_idx = get_index_that_reaches_target( @@ -202,12 +213,30 @@ def get_times_for_submission(submission, print(f' - {key}: {val}') else: print('Submission did not reach target') + df = pd.DataFrame.from_records(workloads) + print(df) df = df.pivot(index='submission', columns='workload', values=time_col) + print(time_col) return df +def variant_criteria_filter(base_workload, variant_workload): + + def filter(x): + try: + if x[variant_workload] == np.inf: + return np.inf + else: + return x[base_workload] + except KeyError as e: + print(x.keys()) + raise e + + return filter + + def compute_performance_profiles(results, time_col='global_step', min_tau=1.0, @@ -215,7 +244,9 @@ def compute_performance_profiles(results, reference_submission_tag=None, num_points=100, scale='linear', - verbosity=0): + verbosity=0, + strict=False, + self_tuning_ruleset=False): """Compute performance profiles for a set of submission by some time column. 
Args:
@@ -247,9 +278,37 @@ def compute_performance_profiles(results,
        f'\nComputing performance profile with respect to `{time_col}` for '
        f'{submission_tag}')
    dfs.append(
-        get_times_for_submission(result, submission_tag, time_col, verbosity))
+        get_times_for_submission(result,
+                                 submission_tag,
+                                 time_col,
+                                 verbosity,
+                                 self_tuning_ruleset,
+                                 strict))
  df = pd.concat(dfs)

+  # Set the score to inf if it is not within 4x of the fastest submission.
+  best_scores = df.min(axis=0)
+  df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf
+
+  # For each held-out workload: if the variant target was not hit, set the
+  # corresponding base workload's score to inf.
+  framework = None
+  for workload in df.keys():
+    # Check if this is a variant.
+    framework = workload.split('_')[-1]
+    workload_ = workload.split(f'_{framework}')[0]
+    if workload_ not in BASE_WORKLOADS:
+      # The variant does not have a finite score, so score the base workload
+      # as inf as well.
+      base_workload = get_base_workload_name(workload_)
+      df[base_workload] = df.apply(
+          variant_criteria_filter(base_workload + f'_{framework}', workload),
+          axis=1)
+
+  base_workloads = [w + f'_{framework}' for w in BASE_WORKLOADS]
+  df = df[base_workloads]
+  print(df)
+
  if verbosity > 0:
    logging.info('\n`{time_col}` to reach target:')
    with pd.option_context('display.max_rows',
@@ -288,7 +347,7 @@
      np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0)

  def rho(r, tau):
-    return (r <= tau).sum(axis=1) / NUM_WORKLOADS
+    return (r <= tau).sum(axis=1) / NUM_BASE_WORKLOADS

  perf_df = pd.concat([rho(df, tau) for tau in points], axis=1)

diff --git a/scoring/score_submission.py b/scoring/score_submission.py
index 0dd84ff55..e0a32777f 100644
--- a/scoring/score_submission.py
+++ b/scoring/score_submission.py
@@ -22,6 +22,11 @@
 flags.DEFINE_boolean('compute_performance_profiles',
                      False,
                      'Whether or not to compute the performance profiles.')
+flags.DEFINE_boolean(
+    'strict',
+    False,
+    'Whether to enforce scoring criteria on variant'
+    'performance and on 5-trial median performance')
 FLAGS = flags.FLAGS

From 370687deb8abfcff8e9393755d6f80bf0f5d2d2b Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 18 Jan 2024 18:43:28 +0000
Subject: [PATCH 094/155] add flag for self-tuning ruleset

---
 scoring/score_submission.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/scoring/score_submission.py b/scoring/score_submission.py
index e0a32777f..eafb41ce6 100644
--- a/scoring/score_submission.py
+++ b/scoring/score_submission.py
@@ -27,6 +27,11 @@
    False,
    'Whether to enforce scoring criteria on variant'
    'performance and on 5-trial median performance')
+flags.DEFINE_boolean(
+    'self_tuning_ruleset',
+    False,
+    'Whether to score on self-tuning ruleset or externally tuned ruleset'
+)
 FLAGS = flags.FLAGS

@@ -82,6 +87,7 @@
scale='linear', verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, strict=FLAGS.strict) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) From 2128ce8bf7fd600e351f951f4fec5493414f7202 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Jan 2024 18:51:00 +0000 Subject: [PATCH 095/155] score group of submissions --- scoring/score_submission.py | 103 ------------------------------------ 1 file changed, 103 deletions(-) delete mode 100644 scoring/score_submission.py diff --git a/scoring/score_submission.py b/scoring/score_submission.py deleted file mode 100644 index eafb41ce6..000000000 --- a/scoring/score_submission.py +++ /dev/null @@ -1,103 +0,0 @@ -import operator -import os - -from absl import app -from absl import flags -from absl import logging -import numpy as np -import pandas as pd -import scoring_utils -from tabulate import tabulate - -from scoring import performance_profile - -flags.DEFINE_string( - 'experiment_path', - None, - 'Path to experiment directory containing workload directories.') -flags.DEFINE_string('submission_tag', 'my.submission', 'Submission tag.') -flags.DEFINE_string('output_dir', - 'scoring_results', - 'Path to save performance profile table and plot.') -flags.DEFINE_boolean('compute_performance_profiles', - False, - 'Whether or not to compute the performance profiles.') -flags.DEFINE_boolean( - 'strict', - False, - 'Whether to enforce scoring criteria on variant' - 'performance and on 5-trial median performance') -flags.DEFINE_boolean( - 'self_tuning_ruleset', - False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset' -) -FLAGS = flags.FLAGS - - -def get_summary_df(workload, workload_df): - validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) - is_minimized = performance_profile.check_if_minimized(validation_metric) - target_op = operator.le if is_minimized else operator.ge - best_op = min if is_minimized else max - idx_op = np.argmin if is_minimized else np.argmax - - summary_df = pd.DataFrame() - summary_df['workload'] = workload_df['workload'] - summary_df['trial'] = workload_df['trial'] - summary_df['target metric name'] = validation_metric - summary_df['target metric value'] = validation_target - - summary_df['target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) - summary_df['best target'] = workload_df[validation_metric].apply( - lambda x: best_op(x)) - workload_df['index best eval'] = workload_df[validation_metric].apply( - lambda x: idx_op(x)) - summary_df['submission time'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval']], axis=1) - summary_df['score'] = summary_df.apply( - lambda x: x['submission time'] if x['target reached'] else np.inf, axis=1) - - return summary_df - - -def main(_): - df = scoring_utils.get_experiment_df(FLAGS.experiment_path) - results = { - FLAGS.submission_tag: df, - } - print(df) - - dfs = [] - for workload, group in df.groupby('workload'): - summary_df = get_summary_df(workload, group) - dfs.append(summary_df) - - df = pd.concat(dfs) - logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql')) - - if FLAGS.compute_performance_profiles: - performance_profile_df = performance_profile.compute_performance_profiles( - results, - time_col='score', - min_tau=1.0, - max_tau=None, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - self_tuning_ruleset=FLAGS.self_tuning_ruleset, - 
strict=FLAGS.strict) - if not os.path.exists(FLAGS.output_dir): - os.mkdir(FLAGS.output_dir) - performance_profile.plot_performance_profiles( - performance_profile_df, 'score', save_dir=FLAGS.output_dir) - perf_df = tabulate( - performance_profile_df.T, headers='keys', tablefmt='psql') - logging.info(f'Performance profile:\n {perf_df}') - - -if __name__ == '__main__': - flags.mark_flag_as_required('experiment_path') - app.run(main) From beca0dcd77ca9f4c63145380845eb7f917a938da Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 22:03:19 +0000 Subject: [PATCH 096/155] update fastmri targets --- .../workloads/fastmri/fastmri_jax/workload.py | 16 +++++++++++++ .../fastmri/fastmri_pytorch/workload.py | 24 +++++++++++++++++++ .../workloads/fastmri/workload.py | 4 ++-- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index e141d7447..cf596268d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -184,6 +184,14 @@ def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" return True + @property + def validation_target_value(self) -> float: + return 0.717840 + + @property + def test_target_value(self) -> float: + return 0.734505 + class FastMRILayerNormWorkload(FastMRIWorkload): @@ -191,3 +199,11 @@ class FastMRILayerNormWorkload(FastMRIWorkload): def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" return True + + @property + def validation_target_value(self) -> float: + return 0.723284 + + @property + def test_target_value(self) -> float: + return 0.739996 diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 5c40d0bb8..d3f49eb1d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -272,6 +272,14 @@ def num_pool_layers(self) -> bool: def num_channels(self) -> bool: """Whether or not to use tanh activations in the model.""" return 64 + + @property + def validation_target_value(self) -> float: + return 0.723559 + + @property + def test_target_value(self) -> float: + return 0.740726 class FastMRITanhWorkload(FastMRIWorkload): @@ -280,6 +288,14 @@ class FastMRITanhWorkload(FastMRIWorkload): def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" return True + + @property + def validation_target_value(self) -> float: + return 0.717840 + + @property + def test_target_value(self) -> float: + return 0.734505 class FastMRILayerNormWorkload(FastMRIWorkload): @@ -288,3 +304,11 @@ class FastMRILayerNormWorkload(FastMRIWorkload): def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" return True + + @property + def validation_target_value(self) -> float: + return 0.723284 + + @property + def test_target_value(self) -> float: + return 0.739996 diff --git a/algorithmic_efficiency/workloads/fastmri/workload.py b/algorithmic_efficiency/workloads/fastmri/workload.py index e3f66ee8a..a8fd1abbb 100644 --- a/algorithmic_efficiency/workloads/fastmri/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/workload.py @@ -39,14 +39,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property 
def validation_target_value(self) -> float: - return 0.727120 + return 0.723653 def has_reached_test_target(self, eval_result: float) -> bool: return eval_result['test/ssim'] > self.test_target_value @property def test_target_value(self) -> float: - return 0.744296 + return 0.740633 @property def loss_type(self) -> spec.LossType: From 7831c3f94ef5f9e256677da9b6d71d2a1f6fac2b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 22:07:15 +0000 Subject: [PATCH 097/155] add ogbg variant targets --- .../workloads/ogbg/ogbg_jax/workload.py | 23 ++++++++++++++++++ .../workloads/ogbg/ogbg_pytorch/workload.py | 24 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index 9fc24552d..a377692bd 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -129,6 +129,14 @@ class OgbgGeluWorkload(OgbgWorkload): def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'gelu' + + @property + def validation_target_value(self) -> float: + return 0.27771 + + @property + def test_target_value(self) -> float: + return 0.262926 class OgbgSiluWorkload(OgbgWorkload): @@ -138,6 +146,13 @@ def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'silu' + @property + def validation_target_value(self) -> float: + return 0.282178 + + @property + def test_target_value(self) -> float: + return 0.272144 class OgbgModelSizeWorkload(OgbgWorkload): @@ -152,3 +167,11 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: return 3 + + @property + def validation_target_value(self) -> float: + return 0.269446 + + @property + def test_target_value(self) -> float: + return 0.253051 diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index beb518e0f..ec5db99a6 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -249,6 +249,14 @@ class OgbgGeluWorkload(OgbgWorkload): def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'gelu' + + @property + def validation_target_value(self) -> float: + return 0.27771 + + @property + def test_target_value(self) -> float: + return 0.262926 class OgbgSiluWorkload(OgbgWorkload): @@ -257,6 +265,14 @@ class OgbgSiluWorkload(OgbgWorkload): def activation_fn_name(self) -> str: """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" return 'silu' + + @property + def validation_target_value(self) -> float: + return 0.282178 + + @property + def test_target_value(self) -> float: + return 0.272144 class OgbgModelSizeWorkload(OgbgWorkload): @@ -272,3 +288,11 @@ def latent_dim(self) -> int: @property def num_message_passing_steps(self) -> int: return 3 + + @property + def validation_target_value(self) -> float: + return 0.269446 + + @property + def test_target_value(self) -> float: + return 0.253051 \ No newline at end of file From 13720e4108b900195503eda45913be38dd9ed373 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 22:11:39 +0000 Subject: [PATCH 098/155] add resnet variant targets --- .../imagenet_resnet/imagenet_jax/workload.py | 24 +++++++++++++++++++ .../imagenet_pytorch/workload.py | 16 +++++++++++++ 2 files changed, 40 insertions(+) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index cef615361..8522569e0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -270,12 +270,28 @@ class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): def use_silu(self) -> bool: return True + @property + def validation_target_value(self) -> float: + return 0.22009 + + @property + def test_target_value(self) -> float: + return 0.3426 + class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @property def use_gelu(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.22077 + + @property + def test_target_value(self) -> float: + return 0.3402 class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): @@ -283,3 +299,11 @@ class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): @property def bn_init_scale(self) -> float: return 8.0 + + @property + def validation_target_value(self) -> float: + return 0.23474 + + @property + def test_target_value(self) -> float: + return 0.3577 diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index bcc1f87c5..4c74e0691 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -306,6 +306,14 @@ class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): @property def use_silu(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.22009 + + @property + def test_target_value(self) -> float: + return 0.342 class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @@ -313,6 +321,14 @@ class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @property def use_gelu(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.22077 + + @property + def test_target_value(self) -> float: + return 0.3402 class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): From b5e3a99b8b49d8e47a8022ff738a2446df9abc94 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 22:30:40 +0000 Subject: [PATCH 099/155] add vit variant targets --- .../imagenet_vit/imagenet_jax/workload.py | 24 +++++++++++++++++++ .../imagenet_vit/imagenet_pytorch/workload.py | 17 +++++++++++++ 2 files changed, 41 insertions(+) diff --git 
a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 4b12247c2..0bc1cd8cc 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -94,6 +94,14 @@ class ImagenetVitGluWorkload(ImagenetVitWorkload): def use_glu(self) -> bool: return True + @property + def validation_target_value(self) -> float: + return 0.2233 + + @property + def test_target_value(self) -> float: + return 0.3455 + class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @@ -101,9 +109,25 @@ class ImagenetVitPostLNWorkload(ImagenetVitWorkload): def use_post_layer_norm(self) -> bool: return True + @property + def validation_target_value(self) -> float: + return 0.24688 + + @property + def test_target_value(self) -> float: + return 0.3714 + class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.22886 + + @property + def test_target_value(self) -> float: + return 0.3477 diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 645b795ca..d77f9713b 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -87,6 +87,14 @@ class ImagenetVitGluWorkload(ImagenetVitWorkload): @property def use_glu(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.2233 + + @property + def test_target_value(self) -> float: + return 0.3455 class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @@ -101,3 +109,12 @@ class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.22886 + + @property + def test_target_value(self) -> float: + return 0.3477 + \ No newline at end of file From 0a3c1d2ff59862f1a6432bdcec940c58f31c5e2e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 22:37:26 +0000 Subject: [PATCH 100/155] add conformer variant targets --- .../librispeech_jax/workload.py | 24 +++++++++++++++++++ .../librispeech_pytorch/workload.py | 24 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f1267f15c..baf0418ae 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -372,12 +372,28 @@ class LibriSpeechConformerAttentionTemperatureWorkload( def attention_temperature(self) -> float: return 1.6 + @property + def validation_target_value(self) -> float: + return 0.082665 + + @property + def test_target_value(self) -> float: + return 0.050168 + class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @property def use_post_layer_norm(self) -> bool: return False + + @property + def validation_target_value(self) -> float: + return 0.085371 + + @property + def test_target_value(self) -> float: + return 0.053096 class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): @@ -385,3 +401,11 @@ class
LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): @property def use_gelu(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.077958 + + @property + def test_target_value(self) -> float: + return 0.047643 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 18853d19c..dd013943a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -337,6 +337,14 @@ class LibriSpeechConformerAttentionTemperatureWorkload( @property def attention_temperature(self) -> float: return 1.6 + + @property + def validation_target_value(self) -> float: + return 0.082665 + + @property + def test_target_value(self) -> float: + return 0.050168 class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @@ -344,6 +352,14 @@ class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @property def use_post_layer_norm(self) -> bool: return False + + @property + def validation_target_value(self) -> float: + return 0.085371 + + @property + def test_target_value(self) -> float: + return 0.053096 class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): @@ -351,3 +367,11 @@ class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): @property def use_gelu(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.077958 + + @property + def test_target_value(self) -> float: + return 0.047643 From 550ff8a9ec203a4b3c541245a8ad1bcf11045da0 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 23:00:23 +0000 Subject: [PATCH 101/155] add fastmri workload variant target --- .../workloads/fastmri/fastmri_jax/workload.py | 8 ++++++++ .../workloads/fastmri/fastmri_pytorch/workload.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index cf596268d..f1db484a0 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -175,6 +175,14 @@ def num_pool_layers(self) -> bool: def num_channels(self) -> bool: """Whether or not to use tanh activations in the model.""" return 64 + + @property + def validation_target_value(self) -> float: + return 0.723559 + + @property + def test_target_value(self) -> float: + return 0.740726 class FastMRITanhWorkload(FastMRIWorkload): diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index d3f49eb1d..777cf4bef 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -288,7 +288,7 @@ class FastMRITanhWorkload(FastMRIWorkload): def use_tanh(self) -> bool: """Whether or not to use tanh activations in the model.""" return True - + @property def validation_target_value(self) -> float: return 0.717840 @@ -304,7 +304,7 @@ class FastMRILayerNormWorkload(FastMRIWorkload): def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" return True - + @property def validation_target_value(self) -> float: return 
0.723284 From aeb934e18ae27158b33dd722b7d3ea5f796b8392 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 23:01:08 +0000 Subject: [PATCH 102/155] reformat --- .../workloads/fastmri/fastmri_jax/workload.py | 2 +- .../workloads/fastmri/fastmri_pytorch/workload.py | 4 ++-- .../workloads/imagenet_resnet/imagenet_jax/workload.py | 2 +- .../workloads/imagenet_resnet/imagenet_pytorch/workload.py | 4 ++-- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 3 +-- .../librispeech_conformer/librispeech_jax/workload.py | 2 +- .../librispeech_conformer/librispeech_pytorch/workload.py | 4 ++-- algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py | 3 ++- .../workloads/ogbg/ogbg_pytorch/workload.py | 6 +++--- 9 files changed, 15 insertions(+), 15 deletions(-) diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index f1db484a0..1476926e3 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -175,7 +175,7 @@ def num_pool_layers(self) -> bool: def num_channels(self) -> bool: """Whether or not to use tanh activations in the model.""" return 64 - + @property def validation_target_value(self) -> float: return 0.723559 diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 777cf4bef..74f6aa13d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -272,7 +272,7 @@ def num_pool_layers(self) -> bool: def num_channels(self) -> bool: """Whether or not to use tanh activations in the model.""" return 64 - + @property def validation_target_value(self) -> float: return 0.723559 @@ -304,7 +304,7 @@ class FastMRILayerNormWorkload(FastMRIWorkload): def use_layer_norm(self) -> bool: """Whether or not to use tanh activations in the model.""" return True - + @property def validation_target_value(self) -> float: return 0.723284 diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 8522569e0..486bf7980 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -284,7 +284,7 @@ class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @property def use_gelu(self) -> bool: return True - + @property def validation_target_value(self) -> float: return 0.22077 diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 4c74e0691..c3a000d9f 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -306,7 +306,7 @@ class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload): @property def use_silu(self) -> bool: return True - + @property def validation_target_value(self) -> float: return 0.22009 @@ -321,7 +321,7 @@ class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @property def use_gelu(self) -> bool: return True - + @property def validation_target_value(self) -> float: return 0.22077 diff --git 
a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index d77f9713b..f7c9c12d1 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -87,7 +87,7 @@ class ImagenetVitGluWorkload(ImagenetVitWorkload): @property def use_glu(self) -> bool: return True - + @property def validation_target_value(self) -> float: return 0.2233 @@ -117,4 +117,3 @@ def validation_target_value(self) -> float: @property def test_target_value(self) -> float: return 0.3477 - \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index baf0418ae..32896aaa6 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -386,7 +386,7 @@ class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @property def use_post_layer_norm(self) -> bool: return False - + @property def validation_target_value(self) -> float: return 0.085371 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index dd013943a..20f27b150 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -337,7 +337,7 @@ class LibriSpeechConformerAttentionTemperatureWorkload( @property def attention_temperature(self) -> float: return 1.6 - + @property def validation_target_value(self) -> float: return 0.082665 @@ -352,7 +352,7 @@ class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @property def use_post_layer_norm(self) -> bool: return False - + @property def validation_target_value(self) -> float: return 0.085371 diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py index a377692bd..ec0c0658d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py @@ -129,7 +129,7 @@ class OgbgGeluWorkload(OgbgWorkload): def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'gelu' - + @property def validation_target_value(self) -> float: return 0.27771 @@ -154,6 +154,7 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: return 0.272144 + class OgbgModelSizeWorkload(OgbgWorkload): @property diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index ec5db99a6..d4817226d 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -249,7 +249,7 @@ class OgbgGeluWorkload(OgbgWorkload): def activation_fn_name(self) -> str: """Name of the activation function to use. 
One of 'relu', 'gelu', 'silu'.""" return 'gelu' - + @property def validation_target_value(self) -> float: return 0.27771 @@ -265,7 +265,7 @@ class OgbgSiluWorkload(OgbgWorkload): def activation_fn_name(self) -> str: """Name of the activation function to use. One of 'relu', 'gelu', 'silu'.""" return 'silu' - + @property def validation_target_value(self) -> float: return 0.282178 @@ -295,4 +295,4 @@ def validation_target_value(self) -> float: @property def test_target_value(self) -> float: - return 0.253051 \ No newline at end of file + return 0.253051 From 454f50183a9ed5f152767790b3b95245fce39543 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 23:05:55 +0000 Subject: [PATCH 103/155] add resnet variant target --- .../imagenet_resnet/imagenet_pytorch/workload.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index c3a000d9f..a4d12a839 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -336,3 +336,11 @@ class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): @property def bn_init_scale(self) -> float: return 8.0 + + @property + def validation_target_value(self) -> float: + return 0.23474 + + @property + def test_target_value(self) -> float: + return 0.3577 From fde8bbcbe439773ca3f6c70c5748044a2dfbe5b5 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 23:08:08 +0000 Subject: [PATCH 104/155] add resnet variant target --- .../imagenet_resnet/imagenet_pytorch/workload.py | 2 +- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index a4d12a839..6727054c9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -336,7 +336,7 @@ class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): @property def bn_init_scale(self) -> float: return 8.0 - + @property def validation_target_value(self) -> float: return 0.23474 diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index f7c9c12d1..f044d102f 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -102,6 +102,14 @@ class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True + + @property + def validation_target_value(self) -> float: + return 0.24688 + + @property + def test_target_value(self) -> float: + return 0.3714 class ImagenetVitMapWorkload(ImagenetVitWorkload): From 5f2b331651418a526b375a0258942966178280bf Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 23:09:43 +0000 Subject: [PATCH 105/155] formatting --- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py 
b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index f044d102f..ff67477a5 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -102,7 +102,7 @@ class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True - + @property def validation_target_value(self) -> float: return 0.24688 From 6e1ff6b5d26ae18ce545dfdd6f52a8a710bbd3fe Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 18 Jan 2024 23:12:55 +0000 Subject: [PATCH 106/155] trailing whitespace --- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index ff67477a5..e672e8d22 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -102,7 +102,7 @@ class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True - + @property def validation_target_value(self) -> float: return 0.24688 From d43ccf4782637c51643d470d865abf203c29665d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 19 Jan 2024 23:11:25 +0000 Subject: [PATCH 107/155] correct max number of steps --- scoring/run_workloads.py | 141 +++++++++++++++++++++++++++++++++++ scoring/score_submissions.py | 104 ++++++++++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 scoring/run_workloads.py create mode 100644 scoring/score_submissions.py diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py new file mode 100644 index 000000000..1804d157d --- /dev/null +++ b/scoring/run_workloads.py @@ -0,0 +1,141 @@ +""" +Example Usage: +python run_all_workloads.py --framework jax \ +--experiment_basename my_first_experiment \ +--docker_image_url \ +--tag \ +--run_percentage 10 \ +--submission_path \ +--tuning_search_space +""" + +from absl import flags +from absl import app +import os +import docker +import time + + +flags.DEFINE_string('tag', None, 'Optional Docker image tag') +flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') +flags.DEFINE_integer('run_percentage', 100, 'Percentage of max num steps to run for.') +flags.DEFINE_string('experiment_basename', 'my_experiment', 'Name of top sub directory in experiment dir.') +flags.DEFINE_boolean('rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.') +flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') +flags.DEFINE_string('submission_path', + 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', + 'Path to reference submission.') +flags.DEFINE_string('tuning_search_space', + 'prize_qualification_baselines/external_tuning/tuning_search_space.json', + 'Path to tuning search space.') +flags.DEFINE_string('framework', + None, + 'Can be either PyTorch or JAX.') +flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') + + +FLAGS = flags.FLAGS + + +DATASETS = ['imagenet', + 'fastmri', + 'ogbg', + 'wmt', + 'librispeech', + 'criteo1tb'] + +WORKLOADS = { + 'imagenet_resnet': {'max_steps': 186_666, + 'dataset': 'imagenet'}, + 
'imagenet_vit': {'max_steps': 186_666, + 'dataset': 'imagenet'}, + 'fastmri': {'max_steps': 36_189, + 'dataset': 'fastmri'}, + 'ogbg': {'max_steps': 80_000, + 'dataset': 'ogbg'}, + 'wmt': {'max_steps': 133_333, + 'dataset': 'wmt'}, + 'librispeech_deepspeech': {'max_steps': 48_000, + 'dataset': 'librispeech'}, + 'criteo1tb': {'max_steps': 10_666, + 'dataset': 'criteo1tb'}, + 'librispeech_conformer': {'max_steps': 80_000, + 'dataset': 'librispeech'}, + } + +def container_running(): + docker_client = docker.from_env() + containers = docker_client.containers.list() + if len(containers) == 0: + return False + else: + return True + +def wait_until_container_not_running(sleep_interval=5*60): + while container_running(): + time.sleep(sleep_interval) + return + +def main(_): + framework = FLAGS.framework + algorithm = FLAGS.algorithm + tag = f':{FLAGS.tag}' if FLAGS.tag is not None else '' + run_fraction = FLAGS.run_percentage/100. + experiment_basename=FLAGS.experiment_basename + rsync_data = 'true' if FLAGS.rsync_data else 'false' + docker_image_url = FLAGS.docker_image_url + submission_path = FLAGS.submisison_path + tuning_search_space = FLAGS.tuning_search_space + + # For each runnable workload check if there are any containers running and if not launch next container command + for workload in WORKLOADS.keys(): + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + print('='*100) + dataset = WORKLOADS[workload]['dataset'] + max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) + experiment_name = f'{experiment_basename}/{algorithm}' + mount_repo_flag = '' + if FLAGS.local: + mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v $HOME/data/:/data/ ' + '-v $HOME/experiment_runs/:/experiment_runs ' + '-v $HOME/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url}{tag} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-t {tuning_search_space} ' + f'-e {experiment_name} ' + f'-m {max_steps} ' + '-c false ' + '-o true ' + f'-r {rsync_data} ' + '-i true ') + if not FLAGS.dry_run: + print('Running docker container command') + print('Container ID: ') + return_code = os.system(command) + else: + return_code = 0 + if return_code == 0: + print(f'SUCCESS: container for {framework} {workload} {algorithm} launched successfully') + print(f'Command: {command}') + print(f'Results will be logged to {experiment_name}') + else: + print(f'Failed: container for {framework} {workload} {algorithm} failed with exit code {return_code}.') + print(f'Command: {command}') + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + + print('='*100) + + +if __name__ == '__main__': + flags.mark_flag_as_required('framework') + flags.mark_flag_as_required() + + app.run(main) \ No newline at end of file diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py new file mode 100644 index 000000000..13a0dc9b2 --- /dev/null +++ b/scoring/score_submissions.py @@ -0,0 +1,104 @@ +import operator +import os + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import pandas as pd +import scoring_utils +from tabulate import tabulate + +from scoring import performance_profile + +flags.DEFINE_string( + 'submission_directory', + None, + 'Path to submission directory containing experiment directories.')
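+# Example usage (a sketch; the paths below are illustrative, not required): +# python3 score_submissions.py \ +# --submission_directory $HOME/experiment_runs \ +# --output_dir scoring_results \ +# --compute_performance_profiles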
+flags.DEFINE_string('output_dir', + 'scoring_results', + 'Path to save performance profile table and plot.') +flags.DEFINE_boolean('compute_performance_profiles', + False, + 'Whether or not to compute the performance profiles.') +flags.DEFINE_boolean( + 'strict', + False, + 'Whether to enforce scoring criteria on variant ' + 'performance and on 5-trial median performance') +flags.DEFINE_boolean( + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset' +) +FLAGS = flags.FLAGS + + +def get_summary_df(workload, workload_df): + validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) + is_minimized = performance_profile.check_if_minimized(validation_metric) + target_op = operator.le if is_minimized else operator.ge + best_op = min if is_minimized else max + idx_op = np.argmin if is_minimized else np.argmax + + summary_df = pd.DataFrame() + summary_df['workload'] = workload_df['workload'] + summary_df['trial'] = workload_df['trial'] + summary_df['target metric name'] = validation_metric + summary_df['target metric value'] = validation_target + + summary_df['target reached'] = workload_df[validation_metric].apply( + lambda x: target_op(x, validation_target)).apply(np.any) + summary_df['best target'] = workload_df[validation_metric].apply( + lambda x: best_op(x)) + workload_df['index best eval'] = workload_df[validation_metric].apply( + lambda x: idx_op(x)) + summary_df['submission time'] = workload_df.apply( + lambda x: x['accumulated_submission_time'][x['index best eval']], axis=1) + summary_df['score'] = summary_df.apply( + lambda x: x['submission time'] if x['target reached'] else np.inf, axis=1) + + return summary_df + +def print_submission_summary(df): + dfs = [] + for workload, group in df.groupby('workload'): + summary_df = get_summary_df(workload, group) + dfs.append(summary_df) + + df = pd.concat(dfs) + logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql')) + + +def main(_): + results = {} + + for submission in os.listdir(FLAGS.submission_directory): + df = scoring_utils.get_experiment_df( + os.path.join(FLAGS.submission_directory, submission)) + results[submission] = df + print_submission_summary(df) + + if FLAGS.compute_performance_profiles: + performance_profile_df = performance_profile.compute_performance_profiles( + results, + time_col='score', + min_tau=1.0, + max_tau=None, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, + strict=FLAGS.strict) + if not os.path.exists(FLAGS.output_dir): + os.mkdir(FLAGS.output_dir) + performance_profile.plot_performance_profiles( + performance_profile_df, 'score', save_dir=FLAGS.output_dir) + perf_df = tabulate( + performance_profile_df.T, headers='keys', tablefmt='psql') + logging.info(f'Performance profile:\n {perf_df}') + + +if __name__ == '__main__': + flags.mark_flag_as_required('submission_directory') + app.run(main) From fb814362e5783441b6cf64dfd090f6626bb5cf0e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 19 Jan 2024 23:34:27 +0000 Subject: [PATCH 108/155] add heldout workloads" --- scoring/run_workloads.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 1804d157d..4df0c50ba 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -29,7 +29,7 @@ 'prize_qualification_baselines/external_tuning/tuning_search_space.json', 'Path to tuning search space.')
flags.DEFINE_string('framework', - None, + 'jax', 'Can be either PyTorch or JAX.') flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') @@ -63,6 +63,19 @@ 'dataset': 'librispeech'}, } + +HELDOUT_WORKLOADS = { + 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', + 'librispeech_conformer_gelu'], + 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', + 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' + ], + 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], + 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'] + 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] +} + def container_running(): docker_client = docker.from_env() containers = docker_client.containers.list() @@ -135,7 +148,5 @@ def main(_): if __name__ == '__main__': - flags.mark_flag_as_required('framework') - flags.mark_flag_as_required() app.run(main) \ No newline at end of file From 2bd89b57033a3f807f7e0190d819b8968f0c71aa Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 19 Jan 2024 15:48:50 -0800 Subject: [PATCH 109/155] Update CHANGELOG.md --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ff1cc068..a0c8ae5ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Change Log +## algoperf-benchmark-0.1.1 (2024-01-19) +Bug fixes to FastMRI metric calculation and targets. + +Added workload variants and targets for ogbg, fastmri, librispeech_conformer, imagenet_resnet, imagenet_vit, criteo1tb to be used as held-out workloads. + ## algoperf-benchmark-0.1.0 (2023-11-28) First release of the AlgoPerf: Training algorithms benchmarking code. From 1ea2282f7d82b39363dfed32aef6af49f40dd130 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 23 Jan 2024 18:49:26 +0000 Subject: [PATCH 110/155] add trial args to docker startup.sh" --- docker/scripts/startup.sh | 18 ++++++++++++++++++ scoring/run_workloads.py | 24 +++++++++--------------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 30cb6b36b..2bd8abf33 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -44,6 +44,9 @@ HOME_DIR="" RSYNC_DATA="true" OVERWRITE="false" SAVE_CHECKPOINTS="true" +NUM_TUNING_TRIALS="1" +HPARAM_START_INDEX="None" +HPARAM_END_INDEX="None" # Pass flag while [ "$1" != "" ]; do @@ -100,6 +103,18 @@ while [ "$1" != "" ]; do shift HOME_DIR=$1 ;; + --num_tuning_trials) + shift + NUM_TUNING_TRIALS=$1 + ;; + --hparam_start_index) + shift + HPARAM_START_INDEX=$1 + ;; + --hparam_end_index) + shift + HPARAM_END_INDEX=$1 + ;; *) usage exit 1 @@ -204,6 +219,9 @@ if [[ ! 
-z ${SUBMISSION_PATH+x} ]]; then --experiment_name=${EXPERIMENT_NAME} \ --overwrite=${OVERWRITE} \ --save_checkpoints=${SAVE_CHECKPOINTS} \ + --num_tuning_trials=${NUM_TUNING_TRIALS} \ + --hparam_start_index=${HPARAM_START_INDEX} \ + --hparam_end_index=${HPARAM_END_INDEX} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 4df0c50ba..dff92aa86 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -1,7 +1,7 @@ """ Example Usage: -python run_all_workloads.py --framework jax \ ---experiment_basename my_first_experiment \ +python run_workloads.py --framework jax \ +--experiment_name my_first_experiment \ --docker_image_url \ --tag \ --run_percentage 10 \ --submission_path \ --tuning_search_space """ @@ -16,10 +16,9 @@ import time -flags.DEFINE_string('tag', None, 'Optional Docker image tag') flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') flags.DEFINE_integer('run_percentage', 100, 'Percentage of max num steps to run for.') -flags.DEFINE_string('experiment_basename', 'my_experiment', 'Name of top sub directory in experiment dir.') +flags.DEFINE_string('experiment_name', 'my_experiment', 'Name of top sub directory in experiment dir.') flags.DEFINE_boolean('rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.') flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') flags.DEFINE_string('submission_path', @@ -91,13 +90,10 @@ def wait_until_container_not_running(sleep_interval=5*60): def main(_): framework = FLAGS.framework - algorithm = FLAGS.algorithm - tag = f':{FLAGS.tag}' if FLAGS.tag is not None else '' run_fraction = FLAGS.run_percentage/100.
- experiment_basename=FLAGS.experiment_basename - rsync_data = 'true' if FLAGS.rsync_data else 'false' + experiment_name=FLAGS.experiment_name docker_image_url = FLAGS.docker_image_url - submission_path = FLAGS.submisison_path + submission_path = FLAGS.submission_path tuning_search_space = FLAGS.tuning_search_space # For each runnable workload check if there are any containers running and if not launch next container command @@ -107,7 +103,6 @@ def main(_): print('='*100) dataset = WORKLOADS[workload]['dataset'] max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) - experiment_name = f'{experiment_basename}/{algorithm}' mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' @@ -116,7 +111,7 @@ def main(_): '-v $HOME/experiment_runs/logs:/logs ' f'{mount_repo_flag}' '--gpus all --ipc=host ' - f'{docker_image_url}{tag} ' + f'{docker_image_url} ' f'-d {dataset} ' f'-f {framework} ' f'-s {submission_path} ' @@ -126,7 +121,6 @@ def main(_): f'-m {max_steps} ' '-c false ' '-o true ' - f'-r {rsync_data} ' '-i true ') if not FLAGS.dry_run: print('Running docker container command') @@ -135,11 +129,11 @@ def main(_): else: return_code = 0 if return_code == 0: - print(f'SUCCESS: container for {framework} {workload} {algorithm} launched successfully') + print(f'SUCCESS: container for {framework} {workload} launched successfully') print(f'Command: {command}') print(f'Results will be logged to {experiment_name}') else: - print(f'Failed: container for {framework} {workload} {algorithm} failed with exit code {return_code}.') + print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') print(f'Command: {command}') wait_until_container_not_running() os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches From 0bcb9691a83ed292543a412d1e6e59b83b35fdd1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 24 Jan 2024 21:53:40 +0000 Subject: [PATCH 111/155] add script for sampling held out workloads --- scoring/generate_held_out_workloads.py | 72 +++++++++++++ scoring/run_workloads.py | 139 +++++++++++++++---------- 2 files changed, 155 insertions(+), 56 deletions(-) create mode 100644 scoring/generate_held_out_workloads.py diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py new file mode 100644 index 000000000..cc5c3df71 --- /dev/null +++ b/scoring/generate_held_out_workloads.py @@ -0,0 +1,72 @@ +from absl import app +from absl import flags +from absl import logging +import struct +import os + +import json +import jax +import jax.numpy as jnp +from algorithmic_efficiency import random_utils as prng + + +flags.DEFINE_integer('seed', None, 'Random seed for scoring.') +flags.DEFINE_string('framework', 'jax', "JAX or") +flags.DEFINE_string('output_filename', 'held_out_workloads.json', 'Path to file to record sampled held_out workloads.') +FLAGS = flags.FLAGS + + +HELD_OUT_WORKLOADS = { + 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', + 'librispeech_conformer_gelu'], + 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', + 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' + ], + 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], + 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], + 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 
'criteo1tb_resnet'] +} + + +def save_held_out_workloads(held_out_workloads, filename): + with open(filename, "w") as f: + json.dump(held_out_workloads, f) + + +def read_held_out_workloads(filename): + with open(filename, "r") as f: + held_out_workloads = json.load(f) + return held_out_workloads + + + +def main(_): + rng_seed = FLAGS.seed + output_filename = FLAGS.output_filename + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] + + logging.info('Using RNG seed %d', rng_seed) + rng_key = prng.PRNGKey(rng_seed) + + sampled_held_out_workloads = [] + for k, v in HELD_OUT_WORKLOADS.items(): + rng_key, rng_sub_key = prng.split(rng_key, 2) + p = jnp.array([1/len(v) for w in v]) + sampled_index = jax.random.categorical(rng_sub_key, p) + sampled_held_out_workloads.append(v[sampled_index]) + + logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") + + save_held_out_workloads(sampled_held_out_workloads, output_filename) + + +if __name__ == '__main__': + app.run(main) + + + + +print(h) \ No newline at end of file diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index dff92aa86..0f56ead78 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -11,9 +11,13 @@ from absl import flags from absl import app +from absl import logging import os import docker import time +import struct + +from algorithmic_efficiency import random_utils as prng flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') @@ -31,6 +35,15 @@ 'jax', 'Can be either PyTorch or JAX.') flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') +flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') +flags.DEFINE_string('study_start_index', None, 'Start index for studies.') +flags.DEFINE_string('study_end_index', None, 'End index for studies.') +flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') +flags.DEFINE_integer('hparam_start_index', None, 'Start index for tuning trials.') +flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') +flags.DEFINE_integer('seed', None, 'Random seed for scoring.') +flags.DEFINE_integer('submission_id', 0, 'Submission ID to generate study and hparam seeds.') +flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') FLAGS = flags.FLAGS @@ -63,18 +76,6 @@ } -HELDOUT_WORKLOADS = { - 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', - 'librispeech_conformer_gelu'], - 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', - 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' - ], - 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], - 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], - 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], - 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] -} - def container_running(): docker_client = docker.from_env() containers = docker_client.containers.list() @@ -95,50 +96,76 @@ def main(_): docker_image_url = FLAGS.docker_image_url submission_path = FLAGS.submission_path tuning_search_space = FLAGS.tuning_search_space - - # For each runnable workload check if there are any containers running and if not launch next container command - for workload in WORKLOADS.keys(): - wait_until_container_not_running() - 
os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - print('='*100) - dataset = WORKLOADS[workload]['dataset'] - max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) - mount_repo_flag = '' - if FLAGS.local: - mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v $HOME/data/:/data/ ' - '-v $HOME/experiment_runs/:/experiment_runs ' - '-v $HOME/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-t {tuning_search_space} ' - f'-e {experiment_name} ' - f'-m {max_steps} ' - '-c false ' - '-o true ' - '-i true ') - if not FLAGS.dry_run: - print('Running docker container command') - print('Container ID: ') - return_code = os.system(command) - else: - return_code = 0 - if return_code == 0: - print(f'SUCCESS: container for {framework} {workload} launched successfully') - print(f'Command: {command}') - print(f'Results will be logged to {experiment_name}') - else: - print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') - print(f'Command: {command}') - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - - print('='*100) + num_studies = FLAGS.num_studies + num_tuning_trials = FLAGS.num_tuning_trials + hparam_start_index = FLAGS.hparam_start_index + hparam_end_index = FLAGS.hparam_end_index + study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 + study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 + submission_id = FLAGS.submission_id + rng_seed = FLAGS.seed + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] + + logging.info('Using RNG seed %d', rng_seed) + rng_key = prng.fold_in(prng.PRNGKey(rng_seed), submission_id) + rng_keys = prng.split(rng_key, 5) + + for study_index, rng_key in zip(range(study_start_index, study_end_index), rng_keys): + print('-' * 100) + print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) + print('-' * 100) + _, rng_seed = rng_key + study_dir = os.path.join(experiment_name, f'study_{index}') + + # For each runnable workload check if there are any containers running and if not launch next container command + for workload in WORKLOADS.keys(): + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + print('='*100) + dataset = WORKLOADS[workload]['dataset'] + max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) + mount_repo_flag = '' + if FLAGS.local: + mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v $HOME/data/:/data/ ' + '-v $HOME/experiment_runs/:/experiment_runs ' + '-v $HOME/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-t {tuning_search_space} ' + f'-e {study_dir} ' + f'-m {max_steps} ' + f'--num_tuning_trials {num_tuning_trials} ' + f'--hparam_start_index {hparam_start_index} ' + f'--hparam_end_index {hparam_end_index} ' + f'--rng_seed {rng_seed} ' + '-c false ' + '-o true ' + '-i true ') + if not FLAGS.dry_run: + print('Running docker container command') + print('Container ID: ') + return_code = os.system(command) + else: + return_code = 0 + if return_code == 0: + 
print(f'SUCCESS: container for {framework} {workload} launched successfully') + print(f'Command: {command}') + print(f'Results will be logged to {experiment_name}') + else: + print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') + print(f'Command: {command}') + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + + print('='*100) if __name__ == '__main__': From ce5f202e06c2d18200fc73e3a6ab6d6397b7fc88 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 01:05:26 +0000 Subject: [PATCH 112/155] add code for run workloads --- scoring/run_workloads.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 0f56ead78..b34f50ece 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -18,6 +18,7 @@ import struct from algorithmic_efficiency import random_utils as prng +from scoring.generate_held_out_workloads import read_held_out_workloads flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') From f431eefc9405adc6609127de32572d943c96435e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:30:43 +0000 Subject: [PATCH 113/155] add workload sampling --- scoring/generate_held_out_workloads.py | 18 +++---------- scoring/run_workloads.py | 37 +++++++++++++++++++------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index cc5c3df71..aa85b9d55 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -10,9 +10,9 @@ from algorithmic_efficiency import random_utils as prng -flags.DEFINE_integer('seed', None, 'Random seed for scoring.') -flags.DEFINE_string('framework', 'jax', "JAX or") +flags.DEFINE_integer('held_out_workloads_seed', None, 'Random seed for scoring.') flags.DEFINE_string('output_filename', 'held_out_workloads.json', 'Path to file to record sampled held_out workloads.') +flags.DEFINE_string('framework', 'jax', 'JAX or PyTorch') FLAGS = flags.FLAGS @@ -34,15 +34,8 @@ def save_held_out_workloads(held_out_workloads, filename): json.dump(held_out_workloads, f) -def read_held_out_workloads(filename): - with open(filename, "r") as f: - held_out_workloads = json.load(f) - return held_out_workloads - - - def main(_): - rng_seed = FLAGS.seed + rng_seed = FLAGS.held_out_workloads_seed output_filename = FLAGS.output_filename if not rng_seed: @@ -65,8 +58,3 @@ def main(_): if __name__ == '__main__': app.run(main) - - - - -print(h) \ No newline at end of file diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index b34f50ece..cfe545b42 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -16,9 +16,10 @@ import docker import time import struct +import json from algorithmic_efficiency import random_utils as prng -from scoring.generate_held_out_workloads import read_held_out_workloads +from algorithmic_efficiency.workloads.workloads import get_base_workload_name flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') @@ -77,6 +78,13 @@ } + +def read_held_out_workloads(filename): + with open(filename, "r") as f: + held_out_workloads = json.load(f) + return held_out_workloads + + def container_running(): docker_client = 
docker.from_env() containers = docker_client.containers.list() @@ -110,23 +118,32 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = prng.fold_in(prng.PRNGKey(rng_seed), submission_id) - rng_keys = prng.split(rng_key, 5) + rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) + + workloads = [w for w in WORKLOADS.keys()] + + # Read held-out workloads + if FLAGS.held_out_workloads_config_path: + held_out_workloads = read_held_out_workloads(FLAGS.held_out_workloads_config_path) + workloads = workloads + held_out_workloads - for study_index, rng_key in zip(range(study_start_index, study_end_index), rng_keys): + for study_index in range(study_start_index, study_end_index): print('-' * 100) print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) print('-' * 100) - _, rng_seed = rng_key - study_dir = os.path.join(experiment_name, f'study_{index}') + rng_key, rng_subkey = prng.split(rng_key) + study_dir = os.path.join(experiment_name, f'study_{study_index}') # For each runnable workload check if there are any containers running and if not launch next container command - for workload in WORKLOADS.keys(): + for workload in workloads: + rng_subkey, run_key = prng.split(rng_subkey) + run_seed = run_key[0] # arbitrary + base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches print('='*100) - dataset = WORKLOADS[workload]['dataset'] - max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) + dataset = WORKLOADS[base_workload_name]['dataset'] + max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' @@ -146,7 +163,7 @@ def main(_): f'--num_tuning_trials {num_tuning_trials} ' f'--hparam_start_index {hparam_start_index} ' f'--hparam_end_index {hparam_end_index} ' - f'--rng_seed {rng_seed} ' + f'--rng_seed {run_seed} ' '-c false ' '-o true ' '-i true ') From f260497025cf7191cfc5883cbc75be46158e5732 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:38:59 +0000 Subject: [PATCH 114/155] formatting --- scoring/generate_held_out_workloads.py | 66 +++--- scoring/run_workloads.py | 301 +++++++++++++------------ 2 files changed, 194 insertions(+), 173 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index aa85b9d55..794a451c2 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -9,52 +9,64 @@ import jax.numpy as jnp from algorithmic_efficiency import random_utils as prng - -flags.DEFINE_integer('held_out_workloads_seed', None, 'Random seed for scoring.') -flags.DEFINE_string('output_filename', 'held_out_workloads.json', 'Path to file to record sampled held_out workloads.') +flags.DEFINE_integer('held_out_workloads_seed', + None, + 'Random seed for scoring.') +flags.DEFINE_string('output_filename', + 'held_out_workloads.json', + 'Path to file to record sampled held_out workloads.') flags.DEFINE_string('framework', 'jax', 'JAX or PyTorch') FLAGS = flags.FLAGS - HELD_OUT_WORKLOADS = { - 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', - 'librispeech_conformer_gelu'], - 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', - 'imagenet_vit_gelu', 
'imagenet_vit_post_ln', 'imagenet_vit_map' + 'librispeech': [ + 'librispeech_conformer_attention_temperature', + 'librispeech_conformer_layernorm', + 'librispeech_conformer_gelu' + ], + 'imagenet': [ + 'imagenet_resnet_silu', + 'imagenet_resnet_gelu', + 'imagenet_resnet_large_bn_init', + 'imagenet_vit_gelu', + 'imagenet_vit_post_ln', + 'imagenet_vit_map' ], 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], - 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] + 'criteo1tb': [ + 'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet' + ] } def save_held_out_workloads(held_out_workloads, filename): - with open(filename, "w") as f: - json.dump(held_out_workloads, f) + with open(filename, "w") as f: + json.dump(held_out_workloads, f) def main(_): - rng_seed = FLAGS.held_out_workloads_seed - output_filename = FLAGS.output_filename + rng_seed = FLAGS.held_out_workloads_seed + output_filename = FLAGS.output_filename + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] - if not rng_seed: - rng_seed = struct.unpack('I', os.urandom(4))[0] - - logging.info('Using RNG seed %d', rng_seed) - rng_key = prng.PRNGKey(rng_seed) + logging.info('Using RNG seed %d', rng_seed) + rng_key = prng.PRNGKey(rng_seed) - sampled_held_out_workloads = [] - for k, v in HELD_OUT_WORKLOADS.items(): - rng_key, rng_sub_key = prng.split(rng_key, 2) - p = jnp.array([1/len(v) for w in v]) - sampled_index = jax.random.categorical(rng_sub_key, p) - sampled_held_out_workloads.append(v[sampled_index]) + sampled_held_out_workloads = [] + for k, v in HELD_OUT_WORKLOADS.items(): + rng_key, rng_sub_key = prng.split(rng_key, 2) + p = jnp.array([1 / len(v) for w in v]) + sampled_index = jax.random.categorical(rng_sub_key, p) + sampled_held_out_workloads.append(v[sampled_index]) - logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") + logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") - save_held_out_workloads(sampled_held_out_workloads, output_filename) + save_held_out_workloads(sampled_held_out_workloads, output_filename) if __name__ == '__main__': - app.run(main) + app.run(main) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index cfe545b42..4f72ebedb 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -14,178 +14,187 @@ from absl import logging import os import docker -import time +import time import struct import json from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency.workloads.workloads import get_base_workload_name - -flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') -flags.DEFINE_integer('run_percentage', 100, 'Percentage of max num steps to run for.') -flags.DEFINE_string('experiment_name', 'my_experiment', 'Name of top sub directory in experiment dir.') -flags.DEFINE_boolean('rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.') +flags.DEFINE_string( + 'docker_image_url', + 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', + 'URL to docker image') +flags.DEFINE_integer('run_percentage', + 100, + 'Percentage of max num steps to run for.') +flags.DEFINE_string('experiment_name', + 'my_experiment', + 'Name of top sub directory in experiment dir.') 
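+# Note: each study's results land under this experiment name; main() below +# joins it with study_<i> and passes the resulting path to the container's +# -e (experiment name) argument.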
+flags.DEFINE_boolean('rsync_data', + True, + 'Whether or not to transfer the data from GCP w rsync.') flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') -flags.DEFINE_string('submission_path', - 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', - 'Path to reference submission.') -flags.DEFINE_string('tuning_search_space', - 'prize_qualification_baselines/external_tuning/tuning_search_space.json', - 'Path to tuning search space.') -flags.DEFINE_string('framework', - 'jax', - 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') +flags.DEFINE_string( + 'submission_path', + 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', + 'Path to reference submission.') +flags.DEFINE_string( + 'tuning_search_space', + 'prize_qualification_baselines/external_tuning/tuning_search_space.json', + 'Path to tuning search space.') +flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') +flags.DEFINE_boolean('dry_run', + False, + 'Whether or not to actually run the command') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_string('study_start_index', None, 'Start index for studies.') flags.DEFINE_string('study_end_index', None, 'End index for studies.') flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') -flags.DEFINE_integer('hparam_start_index', None, 'Start index for tuning trials.') +flags.DEFINE_integer('hparam_start_index', + None, + 'Start index for tuning trials.') flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') flags.DEFINE_integer('seed', None, 'Random seed for scoring.') -flags.DEFINE_integer('submission_id', 0, 'Submission ID to generate study and hparam seeds.') -flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') - +flags.DEFINE_integer('submission_id', + 0, + 'Submission ID to generate study and hparam seeds.') +flags.DEFINE_string('held_out_workloads_config_path', + None, + 'Path to config containing held-out workloads') FLAGS = flags.FLAGS - -DATASETS = ['imagenet', - 'fastmri', - 'ogbg', - 'wmt', - 'librispeech', - 'criteo1tb'] +DATASETS = ['imagenet', 'fastmri', 'ogbg', 'wmt', 'librispeech', 'criteo1tb'] WORKLOADS = { - 'imagenet_resnet': {'max_steps': 186_666, - 'dataset': 'imagenet'}, - 'imagenet_vit': {'max_steps': 186_666, - 'dataset': 'imagenet'}, - 'fastmri': {'max_steps': 36_189, - 'dataset': 'fastmri'}, - 'ogbg': {'max_steps': 80_000, - 'dataset': 'ogbg'}, - 'wmt': {'max_steps': 133_333, - 'dataset': 'wmt'}, - 'librispeech_deepspeech': {'max_steps': 48_000, - 'dataset': 'librispeech'}, - 'criteo1tb': {'max_steps': 10_666, - 'dataset': 'criteo1tb'}, - 'librispeech_conformer': {'max_steps': 80_000, - 'dataset': 'librispeech'}, - } - + 'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'}, + 'imagenet_vit': {'max_steps': 186_666, 'dataset': 'imagenet'}, + 'fastmri': {'max_steps': 36_189, 'dataset': 'fastmri'}, + 'ogbg': {'max_steps': 80_000, 'dataset': 'ogbg'}, + 'wmt': {'max_steps': 133_333, 'dataset': 'wmt'}, + 'librispeech_deepspeech': {'max_steps': 48_000, 'dataset': 'librispeech'}, + 'criteo1tb': {'max_steps': 10_666, 'dataset': 'criteo1tb'}, + 'librispeech_conformer': {'max_steps': 80_000, 'dataset': 'librispeech'}, +} def read_held_out_workloads(filename): - with open(filename, "r") as f: - held_out_workloads = json.load(f) - return held_out_workloads + with open(filename, 
"r") as f: + held_out_workloads = json.load(f) + return held_out_workloads def container_running(): - docker_client = docker.from_env() - containers = docker_client.containers.list() - if len(containers) == 0: - return False - else: - return True - -def wait_until_container_not_running(sleep_interval=5*60): - while container_running(): - time.sleep(sleep_interval) - return - + docker_client = docker.from_env() + containers = docker_client.containers.list() + if len(containers) == 0: + return False + else: + return True + + +def wait_until_container_not_running(sleep_interval=5 * 60): + while container_running(): + time.sleep(sleep_interval) + return + + def main(_): - framework = FLAGS.framework - run_fraction = FLAGS.run_percentage/100. - experiment_name=FLAGS.experiment_name - docker_image_url = FLAGS.docker_image_url - submission_path = FLAGS.submission_path - tuning_search_space = FLAGS.tuning_search_space - num_studies = FLAGS.num_studies - num_tuning_trials = FLAGS.num_tuning_trials - hparam_start_index = FLAGS.hparam_start_index - hparam_end_index = FLAGS.hparam_end_index - study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 - study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 - submission_id = FLAGS.submission_id - rng_seed = FLAGS.seed - - if not rng_seed: - rng_seed = struct.unpack('I', os.urandom(4))[0] - - logging.info('Using RNG seed %d', rng_seed) - rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) - - workloads = [w for w in WORKLOADS.keys()] - - # Read held-out workloads - if FLAGS.held_out_workloads_config_path: - held_out_workloads = read_held_out_workloads(FLAGS.held_out_workloads_config_path) - workloads = workloads + held_out_workloads - - for study_index in range(study_start_index, study_end_index): - print('-' * 100) - print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) - print('-' * 100) - rng_key, rng_subkey = prng.split(rng_key) - study_dir = os.path.join(experiment_name, f'study_{study_index}') - - # For each runnable workload check if there are any containers running and if not launch next container command - for workload in workloads: - rng_subkey, run_key = prng.split(rng_subkey) - run_seed = run_key[0] # arbitrary - base_workload_name = get_base_workload_name(workload) - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - print('='*100) - dataset = WORKLOADS[base_workload_name]['dataset'] - max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) - mount_repo_flag = '' - if FLAGS.local: - mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v $HOME/data/:/data/ ' - '-v $HOME/experiment_runs/:/experiment_runs ' - '-v $HOME/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-t {tuning_search_space} ' - f'-e {study_dir} ' - f'-m {max_steps} ' - f'--num_tuning_trials {num_tuning_trials} ' - f'--hparam_start_index {hparam_start_index} ' - f'--hparam_end_index {hparam_end_index} ' - f'--rng_seed {run_seed} ' - '-c false ' - '-o true ' - '-i true ') - if not FLAGS.dry_run: - print('Running docker container command') - print('Container ID: ') - return_code = os.system(command) - else: - return_code = 0 - if return_code == 0: - print(f'SUCCESS: container for {framework} {workload} launched 
successfully') - print(f'Command: {command}') - print(f'Results will be logged to {experiment_name}') - else: - print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') - print(f'Command: {command}') - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - - print('='*100) + framework = FLAGS.framework + run_fraction = FLAGS.run_percentage / 100. + experiment_name = FLAGS.experiment_name + docker_image_url = FLAGS.docker_image_url + submission_path = FLAGS.submission_path + tuning_search_space = FLAGS.tuning_search_space + num_studies = FLAGS.num_studies + num_tuning_trials = FLAGS.num_tuning_trials + hparam_start_index = FLAGS.hparam_start_index + hparam_end_index = FLAGS.hparam_end_index + study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 + study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 + submission_id = FLAGS.submission_id + rng_seed = FLAGS.seed + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] + + logging.info('Using RNG seed %d', rng_seed) + rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) + + workloads = [w for w in WORKLOADS.keys()] + + # Read held-out workloads + if FLAGS.held_out_workloads_config_path: + held_out_workloads = read_held_out_workloads( + FLAGS.held_out_workloads_config_path) + workloads = workloads + held_out_workloads + + for study_index in range(study_start_index, study_end_index): + print('-' * 100) + print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) + print('-' * 100) + rng_key, rng_subkey = prng.split(rng_key) + study_dir = os.path.join(experiment_name, f'study_{study_index}') + + # For each runnable workload check if there are any containers running and if not launch next container command + for workload in workloads: + rng_subkey, run_key = prng.split(rng_subkey) + run_seed = run_key[0] # arbitrary + base_workload_name = get_base_workload_name(workload) + wait_until_container_not_running() + os.system( + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + print('=' * 100) + dataset = WORKLOADS[base_workload_name]['dataset'] + max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) + mount_repo_flag = '' + if FLAGS.local: + mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v $HOME/data/:/data/ ' + '-v $HOME/experiment_runs/:/experiment_runs ' + '-v $HOME/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-t {tuning_search_space} ' + f'-e {study_dir} ' + f'-m {max_steps} ' + f'--num_tuning_trials {num_tuning_trials} ' + f'--hparam_start_index {hparam_start_index} ' + f'--hparam_end_index {hparam_end_index} ' + f'--rng_seed {run_seed} ' + '-c false ' + '-o true ' + '-i true ') + if not FLAGS.dry_run: + print('Running docker container command') + print('Container ID: ') + return_code = os.system(command) + else: + return_code = 0 + if return_code == 0: + print( + f'SUCCESS: container for {framework} {workload} launched successfully' + ) + print(f'Command: {command}') + print(f'Results will be logged to {experiment_name}') + else: + print( + f'Failed: container for {framework} {workload} failed with exit code {return_code}.' 
+ ) + print(f'Command: {command}') + wait_until_container_not_running() + os.system( + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + + print('=' * 100) if __name__ == '__main__': - app.run(main) \ No newline at end of file + app.run(main) From 1a41f8b82957012c2a13ae9d76d61be91348e061 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:39:28 +0000 Subject: [PATCH 115/155] imports --- scoring/generate_held_out_workloads.py | 9 +++++---- scoring/performance_profile.py | 2 +- scoring/run_workloads.py | 13 +++++++------ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index 794a451c2..c61e637bd 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -1,12 +1,13 @@ +import json +import os +import struct + from absl import app from absl import flags from absl import logging -import struct -import os - -import json import jax import jax.numpy as jnp + from algorithmic_efficiency import random_utils as prng flags.DEFINE_integer('held_out_workloads_seed', diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 9322dfaa7..ef4e97f88 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -35,8 +35,8 @@ import numpy as np import pandas as pd -import algorithmic_efficiency.workloads.workloads as workloads_registry from algorithmic_efficiency.workloads.workloads import get_base_workload_name +import algorithmic_efficiency.workloads.workloads as workloads_registry from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 4f72ebedb..7ccd0ca9b 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -9,17 +9,18 @@ --tuning_search_space """ -from absl import flags -from absl import app -from absl import logging +import json import os -import docker -import time import struct -import json +import time + +from absl import app +from absl import flags +from absl import logging from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency.workloads.workloads import get_base_workload_name +import docker flags.DEFINE_string( 'docker_image_url', From 87df162a3ff5e2ae6b04b74fcbd8226016caeea8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:47:29 +0000 Subject: [PATCH 116/155] make seed splitting parallelizable --- scoring/run_workloads.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 7ccd0ca9b..f285d66d4 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -131,11 +131,12 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - for study_index in range(study_start_index, study_end_index): + rng_subkeys = prng.split(rng_key, num_studies)[study_start_index:study_end_index:] + + for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) print('-' * 100) - rng_key, rng_subkey = prng.split(rng_key) study_dir = os.path.join(experiment_name, f'study_{study_index}') # For each runnable workload check if there are any containers running and if not launch next container command From 9d9cdb9dab98efb0d0f90a7cc1e813a89bd8d95a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 
2024 05:49:46 +0000 Subject: [PATCH 117/155] fix --- scoring/score_submissions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 13a0dc9b2..1f7a3a1e7 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -12,7 +12,7 @@ from scoring import performance_profile flags.DEFINE_string( - 'submission_directory, + 'submission_directory', None, 'Path to submission directory containing experiment directories.') flags.DEFINE_string('output_dir', From 17753071c72770862b9db08ed9493546689c2126 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:50:49 +0000 Subject: [PATCH 118/155] formatting --- scoring/run_workloads.py | 3 ++- scoring/score_submissions.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index f285d66d4..47a47ca58 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -131,7 +131,8 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - rng_subkeys = prng.split(rng_key, num_studies)[study_start_index:study_end_index:] + rng_subkeys = prng.split(rng_key, + num_studies)[study_start_index:study_end_index:] for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 1f7a3a1e7..106c6b1da 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -29,8 +29,7 @@ flags.DEFINE_boolean( 'self_tuning_ruleset', False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset' -) + 'Whether to score on self-tuning ruleset or externally tuned ruleset') FLAGS = flags.FLAGS @@ -60,6 +59,7 @@ def get_summary_df(workload, workload_df): return summary_df + def print_submission_summary(df): dfs = [] for workload, group in df.groupby('workload'): From 2a1170858bf003077c41a42e08daad1571ef37d3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:41:56 +0000 Subject: [PATCH 119/155] held out workloads example --- scoring/held_out_workloads_example.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 scoring/held_out_workloads_example.json diff --git a/scoring/held_out_workloads_example.json b/scoring/held_out_workloads_example.json new file mode 100644 index 000000000..2b3d6d6b2 --- /dev/null +++ b/scoring/held_out_workloads_example.json @@ -0,0 +1 @@ +["librispeech_conformer_gelu", "imagenet_resnet_silu", "ogbg_gelu", "wmt_post_ln", "fastmri_model_size", "criteo1tb_layernorm"] \ No newline at end of file From a8385a21a3c402dc643bd27e31260343ef81e3b5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:44:31 +0000 Subject: [PATCH 120/155] add docker for run_workloads.py --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 20139d4c0..4fa84951f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ setup_requires = # Dependencies of the project: install_requires = absl-py==1.4.0 + docker==7.0.0 numpy>=1.23 pandas>=2.0.1 tensorflow==2.12.0 From ffddbdc1cb7020f16567ee4b4778da353ca2bdb6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:52:23 +0000 Subject: [PATCH 121/155] fix run_workloads.py --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 47a47ca58..6bf09469c 100644 --- 
a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -132,7 +132,7 @@ def main(_): workloads = workloads + held_out_workloads rng_subkeys = prng.split(rng_key, - num_studies)[study_start_index:study_end_index:] + num_studies)[:num_studies:] for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) From 91cdf34351d0f9213918495a14f362fb57e5ad7d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:53:19 +0000 Subject: [PATCH 122/155] fix --- scoring/run_workloads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 6bf09469c..19291ead0 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -49,8 +49,8 @@ False, 'Whether or not to actually run the command') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') -flags.DEFINE_string('study_start_index', None, 'Start index for studies.') -flags.DEFINE_string('study_end_index', None, 'End index for studies.') +flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') +flags.DEFINE_integer('study_end_index', None, 'End index for studies.') flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') flags.DEFINE_integer('hparam_start_index', None, From 95572ad3c1af539b0e93601c9108740aca25e22e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 23:38:27 +0000 Subject: [PATCH 123/155] add rng seed to startup.sh docker script --- docker/scripts/startup.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 2bd8abf33..b06375c34 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -34,6 +34,10 @@ function usage() { from internal GCP bucket. -i | --internal_contributor: If true, allow rsync of data and transfer of experiment results with GCP project. + --num_tuning_trials Number of tuning trials for externally tuned ruleset submission. + --hparam_start_index Should be > 0 and < num_tuning_trials - 1. + --hparam_end_index Should be > 0 and < num_tuning_trials - 1. + --rng_seed RNG seed to pass to workload submission_runner. USAGE exit 1 } @@ -47,6 +51,7 @@ SAVE_CHECKPOINTS="true" NUM_TUNING_TRIALS="1" HPARAM_START_INDEX="None" HPARAM_END_INDEX="None" +RNG_SEED="None" # Pass flag while [ "$1" != "" ]; do @@ -115,6 +120,10 @@ while [ "$1" != "" ]; do shift HPARAM_END_INDEX=$1 ;; + --rng_seed) + shift + RNG_SEED=$1 + ;; *) usage exit 1 @@ -222,6 +231,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --num_tuning_trials={NUM_TUNING_TRIALS} \ --hparam_start_index={HPARAM_START_INDEX} \ --hparam_end_index={HPARAM_END_INDEX} \ + --rng_seed={RNG_SEED} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" From d577d5c5758897937a590288f586cf01307f1267 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 23:51:13 +0000 Subject: [PATCH 124/155] fix --- docker/scripts/startup.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index b06375c34..b5ad18941 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -228,9 +228,9 @@ if [[ ! 
-z ${SUBMISSION_PATH+x} ]]; then --experiment_name=${EXPERIMENT_NAME} \ --overwrite=${OVERWRITE} \ --save_checkpoints=${SAVE_CHECKPOINTS} \ - --num_tuning_trials={NUM_TUNING_TRIALS} \ - --hparam_start_index={HPARAM_START_INDEX} \ - --hparam_end_index={HPARAM_END_INDEX} \ + --num_tuning_trials=${NUM_TUNING_TRIALS} \ + --hparam_start_index=${HPARAM_START_INDEX} \ + --hparam_end_index=${HPARAM_END_INDEX} \ --rng_seed={RNG_SEED} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ From 91ff705f67e2b7ce08d12d150e323b6cf0fa6e7b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 23:53:06 +0000 Subject: [PATCH 125/155] fix --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index b5ad18941..c0328ffb4 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -231,7 +231,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --num_tuning_trials=${NUM_TUNING_TRIALS} \ --hparam_start_index=${HPARAM_START_INDEX} \ --hparam_end_index=${HPARAM_END_INDEX} \ - --rng_seed={RNG_SEED} \ + --rng_seed=${RNG_SEED} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" From 296dc1ecc3ecb6035ce3b9faefbe84ebf89f6fdc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:07:52 +0000 Subject: [PATCH 126/155] fix --- docker/scripts/startup.sh | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index c0328ffb4..b4eff52ff 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -48,10 +48,6 @@ HOME_DIR="" RSYNC_DATA="true" OVERWRITE="false" SAVE_CHECKPOINTS="true" -NUM_TUNING_TRIALS="1" -HPARAM_START_INDEX="None" -HPARAM_END_INDEX="None" -RNG_SEED="None" # Pass flag while [ "$1" != "" ]; do @@ -204,6 +200,22 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then MAX_STEPS_FLAG="--max_global_steps=${MAX_GLOBAL_STEPS}" fi + if [[ ! -z ${NUM_TUNING_TRIALS+x} ]]; then + NUM_TUNING_TRIALS_FLAG="--num_tuning_trials=${NUM_TUNING_TRIALS}" + fi + + if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then + HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}" + fi + + if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then + HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}" + fi + + if [[ ! -z ${RNG_SEED+x} ]]; then + RNG_SEED_FLAG="--rng_seed=${RNG_SEED}" + fi + # Define special flags for imagenet and librispeech workloads if [[ ${DATASET} == "imagenet" ]]; then SPECIAL_FLAGS="--imagenet_v2_data_dir=${DATA_DIR}" @@ -228,10 +240,10 @@ if [[ ! 
-z ${SUBMISSION_PATH+x} ]]; then --experiment_name=${EXPERIMENT_NAME} \ --overwrite=${OVERWRITE} \ --save_checkpoints=${SAVE_CHECKPOINTS} \ - --num_tuning_trials=${NUM_TUNING_TRIALS} \ - --hparam_start_index=${HPARAM_START_INDEX} \ - --hparam_end_index=${HPARAM_END_INDEX} \ - --rng_seed=${RNG_SEED} \ + ${NUM_TUNING_TRIALS_FLAG} \ + ${HPARAM_START_INDEX_FLAG} \ + ${HPARAM_END_INDEX_FLAG} \ + ${RNG_SEED_FLAG} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" From a5b1154343e96b822afd9abcc47ae005728f12f6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:30:14 +0000 Subject: [PATCH 127/155] fix --- .../external_tuning/tuning_search_space.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index 65562905a..b5aff94a2 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -10,6 +10,7 @@ }, { "dropout_rate": 0.0, + "label_smoothing": 0.1, "label_smoothing": 0.2, "learning_rate": 0.0008445074561975979, "one_minus_beta1": 0.11042418465, @@ -19,6 +20,7 @@ }, { "dropout_rate": 0.0, + "label_smoothing": 0.1, "learning_rate": 0.001308209823469072, "one_minus_beta1": 0.02686663061, "beta2": 0.9981232922116359, @@ -27,6 +29,7 @@ }, { "dropout_rate": 0.0, + "label_smoothing": 0.1, "learning_rate": 0.004958460849689891, "one_minus_beta1": 0.13625575743, "beta2": 0.6291854735396584, From 226544dc816c719334f08096f394a80c8f168a4e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:30:25 +0000 Subject: [PATCH 128/155] fix --- .../external_tuning/tuning_search_space.json | 1 + 1 file changed, 1 insertion(+) diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index b5aff94a2..910b9a70a 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -38,6 +38,7 @@ }, { "dropout_rate": 0.1, + "label_smoothing": 0.1, "learning_rate": 0.0017486387539278373, "one_minus_beta1": 0.06733926164, "beta2": 0.9955159689799007, From 6faad0431001a9a645dd4ff45e11febdbff5eb92 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:31:54 +0000 Subject: [PATCH 129/155] fix --- .../external_tuning/tuning_search_space.json | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index 910b9a70a..199f77041 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -10,7 +10,6 @@ }, { "dropout_rate": 0.0, - "label_smoothing": 0.1, "label_smoothing": 0.2, "learning_rate": 0.0008445074561975979, "one_minus_beta1": 0.11042418465, @@ -20,7 +19,7 @@ }, { "dropout_rate": 0.0, - "label_smoothing": 0.1, + "label_smoothing": 0.0, "learning_rate": 0.001308209823469072, "one_minus_beta1": 0.02686663061, "beta2": 0.9981232922116359, @@ -29,7 +28,7 @@ }, { "dropout_rate": 0.0, - "label_smoothing": 0.1, + "label_smoothing": 0.0, "learning_rate": 0.004958460849689891, "one_minus_beta1": 0.13625575743, "beta2": 0.6291854735396584, @@ -38,7 
+37,7 @@ }, { "dropout_rate": 0.1, - "label_smoothing": 0.1, + "label_smoothing": 0.0, "learning_rate": 0.0017486387539278373, "one_minus_beta1": 0.06733926164, "beta2": 0.9955159689799007, From 9e7def9ef7d1ab0e502a3225dc532ab73df162e0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:56:34 +0000 Subject: [PATCH 130/155] fix log message --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 19291ead0..f04e8f8df 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -136,7 +136,7 @@ def main(_): for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) - print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) + print('*' * 40, f'Starting study {study_index}/{num_studies - 1}', '*' * 40) print('-' * 100) study_dir = os.path.join(experiment_name, f'study_{study_index}') From 9b410b71cc75b1ba11017fa679f9582cc71b3a5f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:25:30 +0000 Subject: [PATCH 131/155] fix --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index f04e8f8df..72b43dd9f 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -136,7 +136,7 @@ def main(_): for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) - print('*' * 40, f'Starting study {study_index}/{num_studies - 1}', '*' * 40) + print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40) print('-' * 100) study_dir = os.path.join(experiment_name, f'study_{study_index}') From 7634a0bc6c63798f3c93a5e7bd143237a97c5925 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:29:22 +0000 Subject: [PATCH 132/155] debug --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index b4eff52ff..2f1ebb4b7 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -205,7 +205,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then fi if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then - HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}" + HPARAM_START_INDEX_FLAG="--hparam_start_index=blabla" fi if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then From 235bc69de749670012fd48901ca1c72daf617b11 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:43:46 +0000 Subject: [PATCH 133/155] debugging --- docker/scripts/startup.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 2f1ebb4b7..5e7c74988 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -205,10 +205,11 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then fi if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then - HPARAM_START_INDEX_FLAG="--hparam_start_index=blabla" + HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}" fi if [[ ! 
-z ${HPARAM_END_INDEX+x} ]]; then
+    echo "SETTING FLAGGGGGG"
     HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}"
   fi

From a8d04ccd1ada1bcd97be84b23858440430b1aca8 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Fri, 26 Jan 2024 18:45:24 +0000
Subject: [PATCH 134/155] debugging

---
 docker/scripts/startup.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index 5e7c74988..914e7d640 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -210,6 +210,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
   if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then
     echo "SETTING FLAGGGGGG"
+    echo ${HPARAM_END_INDEX}
     HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}"
   fi

From b2571b230c9808bad1089ceb41df53fac4c106d5 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Fri, 26 Jan 2024 18:50:30 +0000
Subject: [PATCH 135/155] fix

---
 scoring/run_workloads.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py
index 72b43dd9f..3dea262d4 100644
--- a/scoring/run_workloads.py
+++ b/scoring/run_workloads.py
@@ -110,8 +110,10 @@ def main(_):
   tuning_search_space = FLAGS.tuning_search_space
   num_studies = FLAGS.num_studies
   num_tuning_trials = FLAGS.num_tuning_trials
-  hparam_start_index = FLAGS.hparam_start_index
-  hparam_end_index = FLAGS.hparam_end_index
+  if FLAGS.hparam_start_index:
+    hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} '
+  if FLAGS.hparam_end_index:
+    hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} '
   study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0
   study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1
   submission_id = FLAGS.submission_id
@@ -168,8 +170,8 @@ def main(_):
                f'-e {study_dir} '
                f'-m {max_steps} '
                f'--num_tuning_trials {num_tuning_trials} '
-               f'--hparam_start_index {hparam_start_index} '
-               f'--hparam_end_index {hparam_end_index} '
+               f'{hparam_start_index_flag} '
+               f'{hparam_end_index_flag} '
                f'--rng_seed {run_seed} '
                '-c false '
                '-o true '

From 18bc3474c158b63b32991306bc06553520d1e84d Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Fri, 26 Jan 2024 18:50:55 +0000
Subject: [PATCH 136/155] remove debugging statements

---
 docker/scripts/startup.sh | 2 --
 1 file changed, 2 deletions(-)

diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index 914e7d640..b4eff52ff 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -209,8 +209,6 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
   if [[ !
-z ${HPARAM_END_INDEX+x} ]]; then - echo "SETTING FLAGGGGGG" - echo ${HPARAM_END_INDEX} HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}" fi From 4a986985d3c535882c64c4163868f008e0d280e2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:59:48 +0000 Subject: [PATCH 137/155] fix --- scoring/run_workloads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 3dea262d4..82afbfb7a 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -110,6 +110,8 @@ def main(_): tuning_search_space = FLAGS.tuning_search_space num_studies = FLAGS.num_studies num_tuning_trials = FLAGS.num_tuning_trials + hparam_start_index_flag = '' + hparam_end_index_flag = '' if FLAGS.hparam_start_index: hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} ' if FLAGS.hparam_end_index: From 4d38e55e1e2d658765732b1e276137b0472745e1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 19:08:14 +0000 Subject: [PATCH 138/155] formatting --- scoring/run_workloads.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 82afbfb7a..083dafb6a 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -135,8 +135,7 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - rng_subkeys = prng.split(rng_key, - num_studies)[:num_studies:] + rng_subkeys = prng.split(rng_key, num_studies)[:num_studies:] for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) From 4d413f47e155fc8b14d90f02248df90ac2955a0d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 01:28:51 +0000 Subject: [PATCH 139/155] take into account median of studies for scoring --- scoring/performance_profile.py | 47 ++++++++++++-------- scoring/score_submissions.py | 9 ++-- scoring/scoring_utils.py | 78 ++++++++++++++++++---------------- 3 files changed, 78 insertions(+), 56 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index ef4e97f88..9c334ee22 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -50,6 +50,7 @@ NUM_BASE_WORKLOADS = 8 NUM_VARIANT_WORKLOADS = 6 NUM_TRIALS = 5 +NUM_STUDIES = 5 MIN_EVAL_METRICS = [ 'ce_loss', @@ -151,6 +152,7 @@ def get_index_that_reaches_target(workload_df, else: index_reached = target_reached.apply(np.argmax) trial = index_reached.idxmin() + print(trial, index_reached[trial]) return trial, index_reached[trial] @@ -182,27 +184,40 @@ def get_times_for_submission(submission, f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' f'but found {num_workloads} workloads.') for workload, group in submission.groupby('workload'): - num_trials = len(group) - if num_trials != NUM_TRIALS and not self_tuning_ruleset: - if strict: - raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') - else: - logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) - trial_idx, time_idx = get_index_that_reaches_target( - group, validation_metric, validation_target) - if time_idx > -1: - time_val = group[time_col].loc[trial_idx][time_idx] - else: - time_val = float('inf') + time_vals_per_study = [] + num_studies = len(group.groupby('study')) + 
if num_studies != NUM_STUDIES: + if strict: + raise ValueError(f'Expecting {NUM_STUDIES} trials for workload ' + f'{workload} but found {num_studies} trials.') + else: + logging.warning(f'Expecting {NUM_STUDIES} trials for workload ' + f'{workload} but found {num_studies} trials.') + for study, group in group.groupby('study'): + num_trials = len(group) + if num_trials != NUM_TRIALS and not self_tuning_ruleset: + if strict: + raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + else: + logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + + trial_idx, time_idx = get_index_that_reaches_target( + group, validation_metric, validation_target) + if time_idx > -1: + time_val = group[time_col].loc[trial_idx][time_idx] + else: + time_val = float('inf') + time_vals_per_study.append(time_val) + workloads.append({ 'submission': submission_name, 'workload': workload, - time_col: time_val, + time_col: np.median(time_val), }) if verbosity > 0: @@ -215,9 +230,7 @@ def get_times_for_submission(submission, print('Submission did not reach target') df = pd.DataFrame.from_records(workloads) - print(df) df = df.pivot(index='submission', columns='workload', values=time_col) - print(time_col) return df diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 106c6b1da..67e0317ae 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -73,9 +73,12 @@ def print_submission_summary(df): def main(_): results = {} - for submission in os.path.listdir(FLAGS.submission_directory): - df = scoring_utils.get_experiment_df(FLAGS.experiment_path) + for submission in os.listdir(FLAGS.submission_directory): + experiment_path = os.path.join(FLAGS.submission_directory, submission) + df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df + print('SUMMARY ') + print(df.keys()) print_submission_summary(df) if FLAGS.compute_performance_profiles: @@ -100,5 +103,5 @@ def main(_): if __name__ == '__main__': - flags.mark_flag_as_required('experiment_path') + # flags.mark_flag_as_required('submission_directory') app.run(main) diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 8252c75a9..1fff9e6b5 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -150,52 +150,58 @@ def get_experiment_df(experiment_dir): collected together. The directory structure is assumed to be: + experiment_dir - + - + - - eval_measurements.csv + + study + + + + + - eval_measurements.csv Returns: df: DataFrame where indices are trials, columns are - metric names and values are lists. + metric names and values are lists of length num evals. 
e.g - +----+-----------+-----------------------------+--------------------+--------------------+ - | | workload | trial | validation/accuracy| score | - |----+-----------+-----------------------------+--------------------+--------------------| - | 0 | mnist_jax | (trial_1, ) | [0.0911, 0.0949] | [10.6396, 10.6464] | - +----+-----------+-----------------------------+--------------------+--------------------+ + +----+-----------+--------+----------------------------+--------------------+--------------------+ + | | workload | study |trial | validation/accuracy| score | + |----+-----------+--------+----------------------------+--------------------+--------------------| + | 0 | mnist_jax | 0 |(trial_1, ) | [0.0911, 0.0949] | [10.6396, 10.6464] | + +----+-----------+--------+----------------------------+--------------------+--------------------+ """ df = pd.DataFrame() paths = filter( lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir, glob.glob(f"{experiment_dir}*")) for experiment_dir in list(paths): - workload_dirs = os.listdir(experiment_dir) - for workload in workload_dirs: - data = { - 'workload': workload, - } - trial_dirs = [ - t for t in os.listdir(os.path.join(experiment_dir, workload)) - if re.match(TRIAL_DIR_REGEX, t) - ] - for trial in trial_dirs: - eval_measurements_filepath = os.path.join( - experiment_dir, - workload, - trial, - MEASUREMENTS_FILENAME, - ) - try: - trial_df = pd.read_csv(eval_measurements_filepath) - except FileNotFoundError as e: - logging.info(f'Could not read {eval_measurements_filepath}') - continue - data['trial'] = (trial, experiment_dir) - for column in trial_df.columns: - values = trial_df[column].to_numpy() - data[column] = values - trial_df = pd.DataFrame([data]) - df = pd.concat([df, trial_df], ignore_index=True) + study_dirs = os.listdir(experiment_dir) + for study_dir in study_dirs: + workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir)) + for workload in workload_dirs: + data = { + 'workload': workload, + } + logging.info(os.path.join(experiment_dir, study_dir, workload)) + trial_dirs = [ + t for t in os.listdir(os.path.join(experiment_dir, study_dir, workload)) + if re.match(TRIAL_DIR_REGEX, t) + ] + for trial in trial_dirs: + eval_measurements_filepath = os.path.join( + experiment_dir, + study_dir, + workload, + trial, + MEASUREMENTS_FILENAME, + ) + try: + trial_df = pd.read_csv(eval_measurements_filepath) + except FileNotFoundError as e: + logging.info(f'Could not read {eval_measurements_filepath}') + continue + data['trial'] = (trial, experiment_dir) + data['study'] = study_dir + for column in trial_df.columns: + values = trial_df[column].to_numpy() + data[column] = values + trial_df = pd.DataFrame([data]) + df = pd.concat([df, trial_df], ignore_index=True) return df From 84c87b94261a227c1784f82ce009fe4e1733e9a2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 01:29:37 +0000 Subject: [PATCH 140/155] remove debugging --- scoring/performance_profile.py | 1 - scoring/score_submissions.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 9c334ee22..ba0002d5d 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -152,7 +152,6 @@ def get_index_that_reaches_target(workload_df, else: index_reached = target_reached.apply(np.argmax) trial = index_reached.idxmin() - print(trial, index_reached[trial]) return trial, index_reached[trial] diff --git a/scoring/score_submissions.py 
b/scoring/score_submissions.py index 67e0317ae..866030c44 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -77,8 +77,6 @@ def main(_): experiment_path = os.path.join(FLAGS.submission_directory, submission) df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df - print('SUMMARY ') - print(df.keys()) print_submission_summary(df) if FLAGS.compute_performance_profiles: From d6e2a36db2391f00889f77bb2a25c10bac0998bf Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 01:49:50 +0000 Subject: [PATCH 141/155] formatting --- scoring/performance_profile.py | 7 +++---- scoring/scoring_utils.py | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index ba0002d5d..6dc3f00d8 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -190,16 +190,16 @@ def get_times_for_submission(submission, if num_studies != NUM_STUDIES: if strict: raise ValueError(f'Expecting {NUM_STUDIES} trials for workload ' - f'{workload} but found {num_studies} trials.') + f'{workload} but found {num_studies} trials.') else: logging.warning(f'Expecting {NUM_STUDIES} trials for workload ' - f'{workload} but found {num_studies} trials.') + f'{workload} but found {num_studies} trials.') for study, group in group.groupby('study'): num_trials = len(group) if num_trials != NUM_TRIALS and not self_tuning_ruleset: if strict: raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') + f'{workload} but found {num_trials} trials.') else: logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' f'{workload} but found {num_trials} trials.') @@ -211,7 +211,6 @@ def get_times_for_submission(submission, else: time_val = float('inf') time_vals_per_study.append(time_val) - workloads.append({ 'submission': submission_name, diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 1fff9e6b5..b17b9c5bc 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -179,7 +179,8 @@ def get_experiment_df(experiment_dir): } logging.info(os.path.join(experiment_dir, study_dir, workload)) trial_dirs = [ - t for t in os.listdir(os.path.join(experiment_dir, study_dir, workload)) + t for t in os.listdir( + os.path.join(experiment_dir, study_dir, workload)) if re.match(TRIAL_DIR_REGEX, t) ] for trial in trial_dirs: From f34838a4ac4b1c8f39da35f24c06dc243e89055e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 02:08:28 +0000 Subject: [PATCH 142/155] documentation --- GETTING_STARTED.md | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 96a7b7d6f..eea06ba67 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -336,11 +336,44 @@ docker exec -it /bin/bash ``` ## Score your Submission +To score your submission we will score over all workloads, held-out workloads and studies as described in the rules. +In other words, the total number of runs expected for official scoring is: +- for external ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) x 5 +- for internal ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) -To produce performance profile and performance table: +You may have the time or compute resources to run all required runs, so our scoring scripts will allow some flexibility. 
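Spelling out the run counts above: the parenthesization that a later patch in this series (PATCH 146) makes explicit is (base workloads + held-out workloads) x studies, times tuning trials for the externally tuned ruleset. A small sketch of the arithmetic under that grouping, assuming 8 base workloads, 6 held-out workloads, 5 studies, and 5 trials:

```python
# Total training runs implied by the scoring rules quoted above.
base_workloads = 8
held_out_workloads = 6
studies = 5
tuning_trials = 5

external_ruleset_runs = (base_workloads + held_out_workloads) * studies * tuning_trials
internal_ruleset_runs = (base_workloads + held_out_workloads) * studies
print(external_ruleset_runs, internal_ruleset_runs)  # 350 70
```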
+
+### Running workloads
+To run workloads for scoring you may specify a "virtual" list of held-out workloads. It is important
+to note that the official set of held-out workloads will be sampled by the competition organizers during scoring time.
+
+An example config for held-out workloads is stored in `scoring/held_out_workloads_example.json`.
+To generate a new sample of held-out workloads, run:
+
+```bash
+python3 generate_held_out_workloads.py --seed <optional_rng_seed> --output_filename <output_filename>
+```
+
+To run a number of studies and trials over all workloads using Docker containers for each run:
+
+```bash
+python scoring/run_workloads.py \
+--framework <framework> \
+--experiment_name <experiment_name> \
+--docker_image_url <docker_image_url> \
+--submission_path <submission_path> \
+--tuning_search_space <tuning_search_space> \
+--held_out_workloads_config_path held_out_workloads_example.json \
+--num_studies <num_studies> \
+--seed <rng_seed>
+```
+
+Note that to run the above script you will need the minimum jax_cpu and pytorch_cpu installations of the algorithmic-efficiency package.
+
+Finally, to get the raw scores and performance profiles of a group of submissions or a single submission:
 
 ```bash
-python3 scoring/score_submission.py --experiment_path=<path_to_experiment_dir> --output_dir=<output_dir>
+python score_submissions.py --submission_directory <submission_directory> --output_dir <output_dir> --compute_performance_profiles
 ```
 
 We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179).

From 84dbb075ca510221fa1f04a7ce0d79cdbff82ae0 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 30 Jan 2024 20:29:47 +0000
Subject: [PATCH 143/155] fix

---
 scoring/performance_profile.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index 6dc3f00d8..6b49253d8 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -215,7 +215,7 @@ def get_times_for_submission(submission,
     workloads.append({
         'submission': submission_name,
         'workload': workload,
-        time_col: np.median(time_val),
+        time_col: np.median(time_vals_per_study),
     })
 
   if verbosity > 0:

From 6d3b0aec04d8003cd9c37f8ce806edb4bb5e8c44 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 31 Jan 2024 05:01:21 +0000
Subject: [PATCH 144/155] remove indexing for rng_subkeys

---
 scoring/run_workloads.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py
index 083dafb6a..af319e67b 100644
--- a/scoring/run_workloads.py
+++ b/scoring/run_workloads.py
@@ -135,7 +135,7 @@ def main(_):
         FLAGS.held_out_workloads_config_path)
     workloads = workloads + held_out_workloads
 
-  rng_subkeys = prng.split(rng_key, num_studies)[:num_studies:]
+  rng_subkeys = prng.split(rng_key, num_studies)
 
   for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys):
     print('-' * 100)

From 7b23443349eb84d3e7c184b29f224a04666d2491 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 31 Jan 2024 05:37:11 +0000
Subject: [PATCH 145/155] add documentation

---
 scoring/run_workloads.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py
index af319e67b..ec8b1f8ab 100644
--- a/scoring/run_workloads.py
+++ b/scoring/run_workloads.py
@@ -56,7 +56,7 @@
                      None,
                      'Start index for tuning trials.')
 flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.')
-flags.DEFINE_integer('seed', None, 'Random seed for scoring.')
+flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.') flags.DEFINE_integer('submission_id', 0, 'Submission ID to generate study and hparam seeds.') @@ -125,7 +125,7 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) + rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))) workloads = [w for w in WORKLOADS.keys()] @@ -145,7 +145,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - rng_subkey, run_key = prng.split(rng_subkey) + run_key = prng.fold_in(rng_subkey, hash(workload)) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() From 4b77dddaa0d221da85532e02820bd9c947f3aa7c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:03:04 +0000 Subject: [PATCH 146/155] fix documentation --- GETTING_STARTED.md | 6 ++++-- scoring/run_workloads.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index eea06ba67..48c23a1a6 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -337,9 +337,11 @@ docker exec -it /bin/bash ## Score your Submission To score your submission we will score over all workloads, held-out workloads and studies as described in the rules. +We will sample 1 held-out workload per dataset for a total of 6 held-out workloads and will use the sampled +held-out workloads in the scoring criteria for the matching base workloads. In other words, the total number of runs expected for official scoring is: -- for external ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) x 5 -- for internal ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) +- for external ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) x 5 (trials) +- for internal ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) You may have the time or compute resources to run all required runs, so our scoring scripts will allow some flexibility. diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index ec8b1f8ab..53d1aa2ee 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -63,10 +63,10 @@ flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') +flags.DEFINE_string('') FLAGS = flags.FLAGS -DATASETS = ['imagenet', 'fastmri', 'ogbg', 'wmt', 'librispeech', 'criteo1tb'] WORKLOADS = { 'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'}, From 2f480096a306db7aa9699c0762b268c41070813f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:13:39 +0000 Subject: [PATCH 147/155] add warning --- scoring/score_submissions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 866030c44..156d1b2f9 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -24,8 +24,9 @@ flags.DEFINE_boolean( 'strict', False, - 'Whether to enforce scoring criteria on variant' - 'performance and on 5-trial median performance') + 'Whether to enforce scoring criteria on variant performance and on' + '5-trial median performance. 
Note that during official scoring this '
+    'flag will be set to True.')
 flags.DEFINE_boolean(
     'self_tuning_ruleset',
     False,
@@ -78,6 +79,12 @@ def main(_):
     df = scoring_utils.get_experiment_df(experiment_path)
     results[submission] = df
     print_submission_summary(df)
+
+  if not FLAGS.strict:
+    logging.warning('You are running with strict=False. This will relax '
+                    'scoring criteria on the held-out workloads, number of trials and number '
+                    'of studies. Your score may not be an accurate representation '
+                    'under competition scoring rules. To enforce the criteria set strict=True.')

From d39eb245cc54f13e42c676a99ef22c3fdd6c1346 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 31 Jan 2024 23:16:26 +0000
Subject: [PATCH 148/155] typo

---
 scoring/performance_profile.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index 6b49253d8..d0351390b 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -166,7 +166,7 @@ def get_times_for_submission(submission,
   Args:
     submission: A DataFrame containing one row for each trial in each workload
       for a given submission.
-    submission_name: Globally unique identified for a submission.
+    submission_name: Globally unique identifier for a submission.
     time_col: A string indicating which column to use for time.
     verbosity: Debug level of information; choice of (1, 2, 3).

From aecb37f93bf6c7df05f003acb525eb05d87d6071 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 31 Jan 2024 23:18:54 +0000
Subject: [PATCH 149/155] fix documentation

---
 GETTING_STARTED.md | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 48c23a1a6..3fbb29ba5 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -343,7 +343,7 @@ In other words, the total number of runs expected for official scoring is:
 - for external ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) x 5 (trials)
 - for internal ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies)
 
-You may have the time or compute resources to run all required runs, so our scoring scripts will allow some flexibility.
+
@@ -372,7 +372,10 @@ python scoring/run_workloads.py \
 
 Note that to run the above script you will need the minimum jax_cpu and pytorch_cpu installations of the algorithmic-efficiency package.
 
-Finally, to get the raw scores and performance profiles of a group of submissions or a single submission:
+During submission development, it might be useful to do faster, approximate scoring (e.g. without 5 different
+studies or when some trials are missing), so the scoring scripts allow some flexibility. To simulate official scoring,
+pass the `--strict=True` flag in score_submissions.py. To get the raw scores and performance profiles of a group of
+submissions or a single submission:
 
 ```bash
 python score_submissions.py --submission_directory <submission_directory> --output_dir <output_dir> --compute_performance_profiles
 ```

From 5135cc85f73f72e5d416dd35d209be258f2aec59 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 31 Jan 2024 23:47:22 +0000
Subject: [PATCH 150/155] remove prng import from generate_held_out_workloads.py

---
 scoring/generate_held_out_workloads.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py
index c61e637bd..db449cb1f 100644
--- a/scoring/generate_held_out_workloads.py
+++ b/scoring/generate_held_out_workloads.py
@@ -1,14 +1,11 @@
 import json
 import os
+import numpy as np
 import struct
 
 from absl import app
 from absl import flags
 from absl import logging
-import jax
-import jax.numpy as jnp
-
-from algorithmic_efficiency import random_utils as prng
 
 flags.DEFINE_integer('held_out_workloads_seed',
                      None,
@@ -55,17 +52,14 @@ def main(_):
     rng_seed = struct.unpack('I', os.urandom(4))[0]
 
   logging.info('Using RNG seed %d', rng_seed)
-  rng_key = prng.PRNGKey(rng_seed)
+  rng = np.random.default_rng(rng_seed)
 
   sampled_held_out_workloads = []
   for k, v in HELD_OUT_WORKLOADS.items():
-    rng_key, rng_sub_key = prng.split(rng_key, 2)
-    p = jnp.array([1 / len(v) for w in v])
-    sampled_index = jax.random.categorical(rng_sub_key, p)
+    sampled_index = rng.integers(len(v))
     sampled_held_out_workloads.append(v[sampled_index])
 
   logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}")
-
   save_held_out_workloads(sampled_held_out_workloads, output_filename)

From 6d4f82e8c81b2a9296b9aa99347ff4e400dbf62e Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 1 Feb 2024 00:27:46 +0000
Subject: [PATCH 151/155] fix technical documentation

---
 DOCUMENTATION.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index a25f5b689..62b9cba0f 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -409,7 +409,7 @@ The held-out workloads function similarly to a holdout test set discouraging sub
 Modifications could, for example, include changing the number of layers or units (drawn from an interval), swapping the activation function (drawn from a set of applicable functions), or using different data augmentations (drawn from a list of possible pre-processing steps). The sample space should be wide enough to discourage submitters from simply trying them all out, but at the same time should be restricted enough to produce realistic workloads with acceptable achievable performances.
 
-In the first iteration of this benchmark, we manually designed three different workloads variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each fixed workload.
+In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each dataset.
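Two notes on the sampling change above. First, `jax.random.categorical` interprets its second argument as unnormalized log-probabilities (logits), so the removed code's uniform vector `[1 / len(v), ...]` yielded a uniform draw only because every entry was equal; `rng.integers(len(v))` states that intent directly and drops the JAX dependency. Second, this is the "one held-out workload per dataset" sampling that the DOCUMENTATION.md line above describes. A minimal sketch, using an abbreviated, illustrative subset of the `HELD_OUT_WORKLOADS` registry:

```python
import numpy as np

# Abbreviated stand-in for the full HELD_OUT_WORKLOADS registry.
HELD_OUT_WORKLOADS = {
    'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
    'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
}

rng = np.random.default_rng(1996)
sampled = [
    variants[rng.integers(len(variants))]  # uniform over the variants
    for variants in HELD_OUT_WORKLOADS.values()
]
print(sampled)  # one uniformly drawn variant per dataset
```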
Our scoring procedure uses the held-out workloads only to penalize submissions that can't handle the introduced modifications (see the [Scoring](#scoring) section for further details). From aaa1014cd156bbb94d4423281ebcd09944d4296c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 1 Feb 2024 01:04:46 +0000 Subject: [PATCH 152/155] formatting --- scoring/run_workloads.py | 42 ++++++++++++++++++++---------------- scoring/score_submissions.py | 12 ++++++----- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 53d1aa2ee..bfd02c476 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -45,9 +45,11 @@ 'prize_qualification_baselines/external_tuning/tuning_search_space.json', 'Path to tuning search space.') flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean('dry_run', - False, - 'Whether or not to actually run the command') +flags.DEFINE_boolean( + 'dry_run', + False, + 'Whether or not to actually run the docker containers. ' + 'If False, simply print the docker run commands. ') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') flags.DEFINE_integer('study_end_index', None, 'End index for studies.') @@ -63,23 +65,20 @@ flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') -flags.DEFINE_string('') +flags.DEFINE_string( + 'workload_meta_data_config_path', + None, + 'Path to config containing dataset and maximum number of steps per workload.' + 'The default values of these are set to the full budgets as determined ' + 'via the target-setting procedure. ' + 'Note that training will be interrupted at either the set maximum number ' + 'of steps or the fixed workload maximum run time, whichever comes first. 
 
 FLAGS = flags.FLAGS
 
 
-WORKLOADS = {
-    'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'},
-    'imagenet_vit': {'max_steps': 186_666, 'dataset': 'imagenet'},
-    'fastmri': {'max_steps': 36_189, 'dataset': 'fastmri'},
-    'ogbg': {'max_steps': 80_000, 'dataset': 'ogbg'},
-    'wmt': {'max_steps': 133_333, 'dataset': 'wmt'},
-    'librispeech_deepspeech': {'max_steps': 48_000, 'dataset': 'librispeech'},
-    'criteo1tb': {'max_steps': 10_666, 'dataset': 'criteo1tb'},
-    'librispeech_conformer': {'max_steps': 80_000, 'dataset': 'librispeech'},
-}
-
-
 def read_held_out_workloads(filename):
   with open(filename, "r") as f:
     held_out_workloads = json.load(f)
@@ -127,7 +126,10 @@ def main(_):
   logging.info('Using RNG seed %d', rng_seed)
   rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id)))
 
-  workloads = [w for w in WORKLOADS.keys()]
+  with open(FLAGS.workload_meta_data_config_path) as f:
+    workload_meta_data = json.load(f)
+
+  workloads = [w for w in workload_meta_data.keys()]
 
   # Read held-out workloads
   if FLAGS.held_out_workloads_config_path:
@@ -152,8 +154,9 @@ def main(_):
       os.system(
           "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'")  # clear caches
       print('=' * 100)
-      dataset = WORKLOADS[base_workload_name]['dataset']
-      max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction)
+      dataset = workload_meta_data[base_workload_name]['dataset']
+      max_steps = int(workload_meta_data[base_workload_name]['max_steps'] *
+                      run_fraction)
       mount_repo_flag = ''
       if FLAGS.local:
         mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency '
@@ -202,5 +205,6 @@ def main(_):
 
 if __name__ == '__main__':
+  flags.mark_flag_as_required('workload_meta_data_config_path')
   app.run(main)

diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py
index 156d1b2f9..aafc5530a 100644
--- a/scoring/score_submissions.py
+++ b/scoring/score_submissions.py
@@ -79,12 +79,14 @@ def main(_):
     df = scoring_utils.get_experiment_df(experiment_path)
     results[submission] = df
     print_submission_summary(df)
-
+
   if not FLAGS.strict:
-    logging.warning('You are running with strict=False. This will relax '
-    'scoring criteria on the held-out workloads, number of trials and number '
-    'of studies. Your score may not be an accurate representation '
-    'under competition scoring rules. To enforce the criteria set strict=True.')
+    logging.warning(
+        'You are running with strict=False. This will relax '
+        'scoring criteria on the held-out workloads, number of trials and number '
+        'of studies. Your score may not be an accurate representation '
+        'under competition scoring rules. To enforce the criteria set strict=True.'
+    )
 
   if FLAGS.compute_performance_profiles:
     performance_profile_df = performance_profile.compute_performance_profiles(
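Worth noting between PATCH 152 and PATCH 153: the one behavioral use of the metadata is the `max_steps` computation, which scales each workload's full step budget by the run fraction and truncates with `int()`. A worked example using the ImageNet ResNet budget from the dict above (the 0.1 fraction is illustrative):

```python
# Worked example of the budget scaling in run_workloads.py.
workload_meta_data = {
    'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'},
}
run_fraction = 0.1  # illustrative; run_workloads.py derives this from its flags

max_steps = int(workload_meta_data['imagenet_resnet']['max_steps'] * run_fraction)
assert max_steps == 18_666  # int() truncates 18666.6 toward zero
```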
From 761a877743ae06df055a7d4d7c8744eb28bb23d0 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 1 Feb 2024 01:06:18 +0000
Subject: [PATCH 153/155] add default for workload metadata config file

---
 scoring/run_workloads.py | 33 ++++++++++++++-------------------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py
index bfd02c476..e9f76566f 100644
--- a/scoring/run_workloads.py
+++ b/scoring/run_workloads.py
@@ -45,11 +45,10 @@
     'prize_qualification_baselines/external_tuning/tuning_search_space.json',
     'Path to tuning search space.')
 flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.')
-flags.DEFINE_boolean(
-    'dry_run',
-    False,
-    'Whether or not to actually run the docker containers. '
-    'If True, simply print the docker run commands. ')
+flags.DEFINE_boolean('dry_run',
+                     False,
+                     'Whether or not to actually run the docker containers. '
+                     'If True, simply print the docker run commands. ')
 flags.DEFINE_integer('num_studies', 5, 'Number of studies to run')
 flags.DEFINE_integer('study_start_index', None, 'Start index for studies.')
 flags.DEFINE_integer('study_end_index', None, 'End index for studies.')
@@ -65,16 +64,15 @@
 flags.DEFINE_string('held_out_workloads_config_path',
                     None,
                     'Path to config containing held-out workloads')
-flags.DEFINE_string(
-    'workload_meta_data_config_path',
-    None,
-    'Path to config containing dataset and maximum number of steps per workload. '
-    'The default values of these are set to the full budgets as determined '
-    'via the target-setting procedure. '
-    'Note that training will be interrupted at either the set maximum number '
-    'of steps or the fixed workload maximum run time, whichever comes first. '
-    'If your algorithm has a smaller per-step time than our baselines '
-    'you may want to increase the number of steps per workload.')
+flags.DEFINE_string('workload_meta_data_config_path',
+                    'workload_meta_data.json',
+                    'Path to config containing dataset and maximum number of steps per workload. '
+                    'The default values of these are set to the full budgets as determined '
+                    'via the target-setting procedure. '
+                    'Note that training will be interrupted at either the set maximum number '
+                    'of steps or the fixed workload maximum run time, whichever comes first. '
+                    'If your algorithm has a smaller per-step time than our baselines '
+                    'you may want to increase the number of steps per workload.')
 
 FLAGS = flags.FLAGS
 
@@ -155,8 +153,7 @@ def main(_):
           "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'")  # clear caches
       print('=' * 100)
       dataset = workload_meta_data[base_workload_name]['dataset']
-      max_steps = int(workload_meta_data[base_workload_name]['max_steps'] *
-                      run_fraction)
+      max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * run_fraction)
       mount_repo_flag = ''
       if FLAGS.local:
         mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency '
@@ -205,6 +202,4 @@ def main(_):
 
 if __name__ == '__main__':
-  flags.mark_flag_as_required('workload_meta_data_config_path')
-
   app.run(main)

From 6b3827a4d63030cce79998688d6885170c4a9dab Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 1 Feb 2024 01:13:24 +0000
Subject: [PATCH 154/155] yapf fix

---
 scoring/run_workloads.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py
index e9f76566f..077ce8d4f 100644
--- a/scoring/run_workloads.py
+++ b/scoring/run_workloads.py
@@ -45,10 +45,11 @@
     'prize_qualification_baselines/external_tuning/tuning_search_space.json',
     'Path to tuning search space.')
 flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.')
-flags.DEFINE_boolean('dry_run',
-                     False,
-                     'Whether or not to actually run the docker containers. '
-                     'If True, simply print the docker run commands. ')
+flags.DEFINE_boolean(
+    'dry_run',
+    False,
+    'Whether or not to actually run the docker containers. '
+    'If True, simply print the docker run commands. ')
 flags.DEFINE_integer('num_studies', 5, 'Number of studies to run')
 flags.DEFINE_integer('study_start_index', None, 'Start index for studies.')
 flags.DEFINE_integer('study_end_index', None, 'End index for studies.')
@@ -64,15 +65,16 @@
 flags.DEFINE_string('held_out_workloads_config_path',
                     None,
                     'Path to config containing held-out workloads')
-flags.DEFINE_string('workload_meta_data_config_path',
-                    'workload_meta_data.json',
-                    'Path to config containing dataset and maximum number of steps per workload. '
-                    'The default values of these are set to the full budgets as determined '
-                    'via the target-setting procedure. '
-                    'Note that training will be interrupted at either the set maximum number '
-                    'of steps or the fixed workload maximum run time, whichever comes first. '
-                    'If your algorithm has a smaller per-step time than our baselines '
-                    'you may want to increase the number of steps per workload.')
+flags.DEFINE_string(
+    'workload_meta_data_config_path',
+    'workload_meta_data.json',
+    'Path to config containing dataset and maximum number of steps per workload. '
+    'The default values of these are set to the full budgets as determined '
+    'via the target-setting procedure. '
+    'Note that training will be interrupted at either the set maximum number '
+    'of steps or the fixed workload maximum run time, whichever comes first. '
+    'If your algorithm has a smaller per-step time than our baselines '
+    'you may want to increase the number of steps per workload.')
 
 FLAGS = flags.FLAGS
 
@@ -153,7 +155,8 @@ def main(_):
           "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'")  # clear caches
       print('=' * 100)
       dataset = workload_meta_data[base_workload_name]['dataset']
-      max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * run_fraction)
+      max_steps = int(workload_meta_data[base_workload_name]['max_steps'] *
+                      run_fraction)
       mount_repo_flag = ''
       if FLAGS.local:
         mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency '

From c0e1aad8aca0a7bc91d0673957847ef20f2158bf Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 1 Feb 2024 01:15:50 +0000
Subject: [PATCH 155/155] import order

---
 scoring/generate_held_out_workloads.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py
index db449cb1f..474c4e7d7 100644
--- a/scoring/generate_held_out_workloads.py
+++ b/scoring/generate_held_out_workloads.py
@@ -1,11 +1,11 @@
 import json
 import os
-import numpy as np
 import struct
 
 from absl import app
 from absl import flags
 from absl import logging
+import numpy as np
 
 flags.DEFINE_integer('held_out_workloads_seed',
                      None,