From 3933a340ae77228c88d25aac91fc52a516045b39 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Nov 2023 05:37:07 +0000 Subject: [PATCH 001/169] criteo variants --- .../criteo1tb/criteo1tb_jax/models.py | 102 ++++++++++++++ .../criteo1tb/criteo1tb_jax/workload.py | 43 +++++- .../criteo1tb/criteo1tb_pytorch/models.py | 130 +++++++++++++++++- .../criteo1tb/criteo1tb_pytorch/workload.py | 46 ++++++- 4 files changed, 313 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index d47f1b484..596afef0e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -7,6 +7,101 @@ import jax.numpy as jnp +class DLRMResNet(nn.Module): + """Define a DLRMResNet model. + + Parameters: + vocab_size: the size of a single unified embedding table. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + num_dense_features: number of dense features as the bottom mlp input. + embed_dim: embedding dimension. + """ + + vocab_size: int = 32 * 128 * 1024 # 4_194_304 + num_dense_features: int = 13 + mlp_bottom_dims: Sequence[int] = (256, 256, 256) + mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) + embed_dim: int = 128 + dropout_rate: float = 0.0 + use_layer_norm: bool = False # Unused. + + @nn.compact + def __call__(self, x, train): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) + cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + + # bottom mlp + mlp_bottom_dims = self.mlp_bottom_dims + + bot_mlp_input = nn.Dense( + mlp_bottom_dims[0], + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])), + )(bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) + + for dense_dim in mlp_bottom_dims[1:]: + x = nn.Dense( + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), + )(bot_mlp_input) + bot_mlp_input += nn.relu(x) + + base_init_fn = jnn.initializers.uniform(scale=1.0) + # Embedding table init and lookup for a single unified table. 
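+    # All 26 categorical features share this single [vocab_size, embed_dim]
+    # table; raw IDs are folded into range with a modulo on the vocab size.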
+ idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + def scaled_init(key, shape, dtype=jnp.float_): + return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) + + embedding_table = self.param( + 'embedding_table', + scaled_init, + [self.vocab_size, self.embed_dim]) + + embed_features = embedding_table[idx_lookup] + batch_size = bot_mlp_input.shape[0] + embed_features = jnp.reshape( + embed_features, (batch_size, 26 * self.embed_dim)) + top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) + mlp_input_dim = top_mlp_input.shape[1] + mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) + top_mlp_input = nn.Dense( + mlp_top_dims[0], + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( + top_mlp_input) + top_mlp_input = nn.relu(top_mlp_input) + for layer_idx, fan_out in Sequence(enumerate(mlp_top_dims))[1:-1]: + fan_in = mlp_top_dims[layer_idx - 1] + x = nn.Dense( + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( + top_mlp_input) + x = nn.relu(x) + if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: + x = nn.Dropout( + rate=self.dropout_rate, deterministic=not train)(x) + top_mlp_input += x + # In the DLRM model the last layer width is always 1. We can hardcode that + # below. + logits = nn.Dense( + 1, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0)))(top_mlp_input) + return logits + + def dot_interact(concat_features): """Performs feature interaction operation between dense or sparse features. Input tensors represent dense or sparse features. 
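The residual wiring that distinguishes DLRMResNet from DlrmSmall (each hidden layer adds relu(Dense(x)) back onto the running activation instead of replacing it) can be reduced to a short Flax sketch. This is illustrative only; the ResidualMLP name and the toy input shape below are placeholders, not part of the patch.

import flax.linen as nn
import jax
import jax.numpy as jnp


class ResidualMLP(nn.Module):
  """Minimal stand-in for the residual bottom/top stacks of DLRMResNet."""
  hidden_dims: tuple = (256, 256, 256)

  @nn.compact
  def __call__(self, x):
    # The first layer changes the width; the remaining layers are residual
    # blocks whose width matches the running activation.
    x = nn.relu(nn.Dense(self.hidden_dims[0])(x))
    for dim in self.hidden_dims[1:]:
      x = x + nn.relu(nn.Dense(dim)(x))
    return x


params = ResidualMLP().init(jax.random.PRNGKey(0), jnp.ones((2, 13)))

This mirrors the bot_mlp_input += nn.relu(x) accumulation in the hunk above; the top stack follows the same pattern, with the final width-1 logits layer kept outside the residual adds.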
@@ -52,6 +147,7 @@ class DlrmSmall(nn.Module): mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) embed_dim: int = 128 dropout_rate: float = 0.0 + use_layer_norm: bool = False @nn.compact def __call__(self, x, train): @@ -67,6 +163,8 @@ def __call__(self, x, train): )( bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) + if self.use_layer_norm: + bot_mlp_input = nn.LayerNorm()(bot_mlp_input) bot_mlp_output = bot_mlp_input batch_size = bot_mlp_output.shape[0] feature_stack = jnp.reshape(bot_mlp_output, @@ -86,6 +184,8 @@ def scaled_init(key, shape, dtype=jnp.float_): embed_features = embedding_table[idx_lookup] embed_features = jnp.reshape(embed_features, [batch_size, -1, self.embed_dim]) + if self.use_layer_norm: + embed_features = nn.LayerNorm()(embed_features) feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) dot_interact_output = dot_interact(concat_features=feature_stack) top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], @@ -103,6 +203,8 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) if layer_idx < (num_layers_top - 1): top_mlp_input = nn.relu(top_mlp_input) + if self.use_layer_norm: + top_mlp_input = nn.LayerNorm()(top_mlp_input) if (self.dropout_rate is not None and self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_input = nn.Dropout( diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index a76a70289..693573064 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -76,13 +76,18 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate - self._model = models.DlrmSmall( + if self.use_resnet: + model_class = models.DLRMResNet + else: + model_class = models.DlrmSmall + self._model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate) + dropout_rate=dropout_rate, + use_layer_norm=self.use_layer_norm) params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 @@ -154,3 +159,37 @@ def _eval_batch(self, class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): vocab_size: int = 32 * 128 * 16 + + +class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.123744 + + @property + def test_target_value(self) -> float: + return 0.126152 + + +class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): + mlp_bottom_dims = (256, 256, 256) + mlp_top_dims = (256, 256, 256, 256, 1) + + @property + def use_resnet(self) -> bool: + """Whether or not to use residual connections in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.124027 + + @property + def test_target_value(self) -> float: + return 0.126468 diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index de6b4d1dd..6d37565b8 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -25,6 +25,122 @@ def forward(self, dense_features, sparse_features): return torch.cat((dense_features, interactions_flat), dim=1) +class DLRMResNet(nn.Module): + """Define a DLRM-ResNet model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + dropout_rate=0.0, + use_layer_norm=False): # use_layer_norm is unused. + del use_layer_norm + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. + num_chucks = 4 + assert vocab_size % num_chucks == 0 + self.embedding_table_chucks = [] + scale = 1.0 / torch.sqrt(self.vocab_size) + for i in range(num_chucks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chucks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + # bottom mlp + self.bot_layers = [] + input_dim = self.num_dense_features + for dense_dim in self.mlp_bottom_dims: + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + self.bot_layers.append(nn.Sequential(*block)) + input_dim = dense_dim + for layer in self.bot_layers: + for module in layer.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + # top mlp + # TODO (JB): Write down the formula here instead of the constant. + input_dims = 506 + self.top_layers = [] + num_layers_top = len(self.mlp_top_dims) + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + block = [] + fan_in = (input_dims + self.embed_dim) if layer_idx == 0 \ + else self.mlp_top_dims[layer_idx - 1] + block.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + block.append(nn.ReLU(inplace=True)) + if (dropout_rate is not None and dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + block.append(nn.Dropout(p=dropout_rate)) + self.top_layers.append(nn.Sequential(*block)) + for layer in self.top_layers: + for module in layer.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + + + def forward(self, x): + # Todo (kasimbeg): add residual layer + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. 
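+    # The 13 continuous inputs go through the bottom MLP stack first.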
+ embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, -1, self.embed_dim]) + # Dot product interactions. + concatenated_dense = self.dot_interact( + dense_features=embedded_dense, sparse_features=embedded_sparse) + + # Final MLP. + logits = self.top_mlp(concatenated_dense) + return logits + + class DlrmSmall(nn.Module): """Define a DLRM-Small model. @@ -43,7 +159,8 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0): + dropout_rate=0.0, + use_layer_norm=False): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -72,6 +189,8 @@ def __init__(self, for dense_dim in self.mlp_bottom_dims: bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): @@ -94,10 +213,16 @@ def __init__(self, top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) self.top_mlp = nn.Sequential(*top_mlp_layers) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( @@ -124,7 +249,8 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, -1, self.embed_dim]) - + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) # Dot product interactions. concatenated_dense = self.dot_interact( dense_features=embedded_dense, sparse_features=embedded_sparse) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 55b68fb2f..c049ac1de 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -10,8 +10,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.models import \ - DlrmSmall +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch import models from algorithmic_efficiency.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -76,13 +75,18 @@ def init_model_fn( torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. 
torch.backends.cudnn.benchmark = False - model = DlrmSmall( + if self.use_resnet: + model_class = models.DLRMResNet + else: + model_class = models.DlrmSmall + model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate) + dropout_rate=dropout_rate, + use_layer_norm=self.use_layer_norm) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -238,3 +242,37 @@ def _eval_batch(self, class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): vocab_size: int = 32 * 128 * 16 + + +class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.123744 + + @property + def test_target_value(self) -> float: + return 0.126152 + + +class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): + mlp_bottom_dims = (256, 256, 256) + mlp_top_dims = (256, 256, 256, 256, 1) + + @property + def use_resnet(self) -> bool: + """Whether or not to use residual connections in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.124027 + + @property + def test_target_value(self) -> float: + return 0.126468 From 5f40758c303df8384686cc0fccd9af5c855bd39f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 10 Nov 2023 05:43:37 +0000 Subject: [PATCH 002/169] add workload variants --- algorithmic_efficiency/workloads/criteo1tb/workload.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 13bd308fb..95e100f72 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -31,6 +31,16 @@ def target_metric_name(self) -> str: def has_reached_validation_target(self, eval_result: Dict[str, float]) -> bool: + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return False + + @property + def use_resnet(self) -> bool: + """Whether or not to use residual connections in the model.""" + return False + + def has_reached_validation_target(self, eval_result: float) -> bool: return eval_result['validation/loss'] < self.validation_target_value @property From 38821bbafbaca6761d5ef3d6e1ec566ab535b68b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 11 Nov 2023 01:21:56 +0000 Subject: [PATCH 003/169] add criteo variants to workload registry --- algorithmic_efficiency/workloads/criteo1tb/workload.py | 9 +++++---- algorithmic_efficiency/workloads/workloads.py | 7 +++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 95e100f72..83d3d24db 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -29,8 +29,11 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'loss' - def has_reached_validation_target(self, eval_result: Dict[str, - float]) -> bool: + def has_reached_validation_target(self, + eval_result: Dict[str,float]) -> 
bool: + return eval_result['validation/loss'] < self.validation_target_value + + @property def use_layer_norm(self) -> bool: """Whether or not to use LayerNorm in the model.""" return False @@ -40,8 +43,6 @@ def use_resnet(self) -> bool: """Whether or not to use residual connections in the model.""" return False - def has_reached_validation_target(self, eval_result: float) -> bool: - return eval_result['validation/loss'] < self.validation_target_value @property def validation_target_value(self) -> float: diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 3178e054c..260897d49 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -19,7 +19,14 @@ 'criteo1tb_test': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + 'criteo1tb_layernorm': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' }, + 'criteo1tb_resnet': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' + } 'fastmri': { 'workload_path': 'fastmri/fastmri', 'workload_class_name': 'FastMRIWorkload', From f28ecbcf8ce44b145b0ef43be22a7ab319fbc039 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 11 Nov 2023 01:22:30 +0000 Subject: [PATCH 004/169] workload registry --- algorithmic_efficiency/workloads/workloads.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 260897d49..125984d36 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -19,6 +19,7 @@ 'criteo1tb_test': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', + } 'criteo1tb_layernorm': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' From 3e44f6997d23bdfc209f29e8b25d7e2840de86cd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 11 Nov 2023 01:53:13 +0000 Subject: [PATCH 005/169] add regression tests --- .github/workflows/regression_tests.yml | 2 +- .../workflows/regression_tests_variants.yml | 75 +++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/regression_tests_variants.yml diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml index 3a0736fa2..94d155cdc 100644 --- a/.github/workflows/regression_tests.yml +++ b/.github/workflows/regression_tests.yml @@ -180,4 +180,4 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ 
github.head_ref || github.ref_name }} -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false \ No newline at end of file diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml new file mode 100644 index 000000000..2798d09d0 --- /dev/null +++ b/.github/workflows/regression_tests_variants.yml @@ -0,0 +1,75 @@ +name: Containerized Regression Tests + +on: + pull_request: + branches: + - 'dev' + +jobs: + build_and_push_jax_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=jax + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + build_and_push_pytorch_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=pytorch + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + criteo_layernorm_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + criteo_resnet_jax: + runs-on: self-hosted + needs: build_and_push_jax_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True 
-r false + criteo_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + criteo_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + \ No newline at end of file From a3a827c10c2562ce4f5d99358ce48d93279c1799 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 11 Nov 2023 01:54:56 +0000 Subject: [PATCH 006/169] add empty line --- .github/workflows/regression_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml index 94d155cdc..3a0736fa2 100644 --- a/.github/workflows/regression_tests.yml +++ b/.github/workflows/regression_tests.yml @@ -180,4 +180,4 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false \ No newline at end of file + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false From e1e9b4ef5f4018963becdeb160f0cade8bd4a8e7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 11 Nov 2023 01:56:14 +0000 Subject: [PATCH 007/169] formatting --- .../criteo1tb/criteo1tb_jax/models.py | 24 +++++++++---------- .../workloads/criteo1tb/workload.py | 7 +++--- 2 files changed, 15 insertions(+), 
16 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 596afef0e..bb7d02879 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -39,7 +39,8 @@ def __call__(self, x, train): kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal( stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])), - )(bot_mlp_input) + )( + bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) for dense_dim in mlp_bottom_dims[1:]: @@ -47,24 +48,24 @@ def __call__(self, x, train): dense_dim, kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )(bot_mlp_input) + )( + bot_mlp_input) bot_mlp_input += nn.relu(x) base_init_fn = jnn.initializers.uniform(scale=1.0) # Embedding table init and lookup for a single unified table. idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) - embedding_table = self.param( - 'embedding_table', - scaled_init, - [self.vocab_size, self.embed_dim]) + embedding_table = self.param('embedding_table', + scaled_init, [self.vocab_size, self.embed_dim]) embed_features = embedding_table[idx_lookup] batch_size = bot_mlp_input.shape[0] - embed_features = jnp.reshape( - embed_features, (batch_size, 26 * self.embed_dim)) + embed_features = jnp.reshape(embed_features, + (batch_size, 26 * self.embed_dim)) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims @@ -88,8 +89,7 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) x = nn.relu(x) if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: - x = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. 
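For reference, the embedding-table initializer shown in the hunk above is a uniform(0, 1) draw scaled down by sqrt(vocab_size). A standalone sketch with toy sizes (purely illustrative; the actual workload uses a 4_194_304 x 128 table):

import jax
import jax.numpy as jnp
from jax.nn import initializers as jnn

vocab_size, embed_dim = 1024, 16  # toy sizes for a quick check
base_init_fn = jnn.uniform(scale=1.0)


def scaled_init(key, shape, dtype=jnp.float32):
  # Entries start in [0, 1) and are shrunk by 1 / sqrt(vocab_size).
  return base_init_fn(key, shape, dtype) / jnp.sqrt(vocab_size)


table = scaled_init(jax.random.PRNGKey(0), (vocab_size, embed_dim))
assert table.shape == (vocab_size, embed_dim)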
@@ -97,8 +97,8 @@ def scaled_init(key, shape, dtype=jnp.float_): 1, kernel_init=jnn.initializers.normal( stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), - bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0)))(top_mlp_input) + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( + top_mlp_input) return logits diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 83d3d24db..4b2dcbf19 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -29,11 +29,11 @@ def target_metric_name(self) -> str: """The name of the target metric (useful for scoring/processing code).""" return 'loss' - def has_reached_validation_target(self, - eval_result: Dict[str,float]) -> bool: + def has_reached_validation_target(self, eval_result: Dict[str, + float]) -> bool: return eval_result['validation/loss'] < self.validation_target_value - @property + @property def use_layer_norm(self) -> bool: """Whether or not to use LayerNorm in the model.""" return False @@ -43,7 +43,6 @@ def use_resnet(self) -> bool: """Whether or not to use residual connections in the model.""" return False - @property def validation_target_value(self) -> float: return 0.123735 From d3334f5e683ba890b6d8cdc026af19f32a428fa6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 13 Nov 2023 20:25:26 +0000 Subject: [PATCH 008/169] regression test fix --- .github/workflows/regression_tests_variants.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 2798d09d0..11957712e 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -54,7 +54,7 @@ jobs: run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_pytorch: + criteo_layernorm_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image steps: @@ -63,7 +63,7 @@ jobs: run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_pytorch: + criteo_resnet_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image steps: From b6443f28e4909cedc3fe4e9f4ebe02570636983d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 13 Nov 2023 21:46:05 +0000 Subject: [PATCH 009/169] add criteo_variants to valid workloads to startup.sh" 
--- docker/scripts/startup.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 410d21532..c6ac2d701 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -23,7 +23,7 @@ function usage() { -t | --tuning_search_space: Path to tuning search space. If relative path, from algorithmic-efficiency top directory. -e | --experiment_name: Name of experiment. -w | --workload: Can be imagenet_resnet, imagenet_vit, criteo1tb, fastmri, - wmt, librispeech_deepspeech, librispeech_conformer. + wmt, librispeech_deepspeech, librispeech_conformer or variant workload. -a | --keep_container_alive: If true, docker container will be kept alive. Useful for developing or debugging. -m | --max_global_steps: Maximum number of global steps for submission. @@ -113,7 +113,8 @@ done VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt" "mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_vit" "fastmri" "ogbg" \ - "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist") + "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \ + "criteo1tb_resnet" "criteo1tb_layernorm") # Set data and experiment paths From 480542387c3acb5af8d92f74ccc23524fd0bc34b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 03:06:34 +0000 Subject: [PATCH 010/169] syntax fix --- algorithmic_efficiency/workloads/workloads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 125984d36..d86db0643 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -19,7 +19,7 @@ 'criteo1tb_test': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload', - } + }, 'criteo1tb_layernorm': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' @@ -27,7 +27,7 @@ 'criteo1tb_resnet': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' - } + }, 'fastmri': { 'workload_path': 'fastmri/fastmri', 'workload_class_name': 'FastMRIWorkload', From b5c197946efd4b1298efc25e85dc18928ddc4b4c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 20:31:58 +0000 Subject: [PATCH 011/169] add bsz for criteo1tb variants --- baselines/adafactor/jax/submission.py | 4 ++++ baselines/adafactor/pytorch/submission.py | 4 ++++ baselines/adamw/jax/submission.py | 4 ++++ baselines/adamw/pytorch/submission.py | 4 ++++ baselines/lamb/jax/submission.py | 4 ++++ baselines/lamb/pytorch/submission.py | 4 ++++ baselines/momentum/jax/submission.py | 4 ++++ baselines/momentum/pytorch/submission.py | 4 ++++ baselines/nadamw/jax/submission.py | 4 ++++ baselines/nadamw/pytorch/submission.py | 4 ++++ baselines/nesterov/jax/submission.py | 4 ++++ baselines/nesterov/pytorch/submission.py | 4 ++++ baselines/sam/jax/submission.py | 4 ++++ baselines/sam/pytorch/submission.py | 4 ++++ baselines/shampoo/jax/submission.py | 4 ++++ 15 files changed, 60 insertions(+) diff --git a/baselines/adafactor/jax/submission.py b/baselines/adafactor/jax/submission.py index ec8020e7e..e9beb2aea 100644 --- a/baselines/adafactor/jax/submission.py +++ b/baselines/adafactor/jax/submission.py @@ -160,6 +160,10 @@ def get_batch_size(workload_name): # Return the global batch size. 
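+  # The layernorm and resnet Criteo variants reuse the base workload's batch size.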
if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/adafactor/pytorch/submission.py b/baselines/adafactor/pytorch/submission.py index e6fef17dc..9cc5af911 100644 --- a/baselines/adafactor/pytorch/submission.py +++ b/baselines/adafactor/pytorch/submission.py @@ -269,6 +269,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/adamw/jax/submission.py b/baselines/adamw/jax/submission.py index 11212c1a0..54679c5ee 100644 --- a/baselines/adamw/jax/submission.py +++ b/baselines/adamw/jax/submission.py @@ -161,6 +161,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/adamw/pytorch/submission.py b/baselines/adamw/pytorch/submission.py index 75a4abbef..3fe3ba74c 100644 --- a/baselines/adamw/pytorch/submission.py +++ b/baselines/adamw/pytorch/submission.py @@ -129,6 +129,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/lamb/jax/submission.py b/baselines/lamb/jax/submission.py index 27d635ee9..24c7f978b 100644 --- a/baselines/lamb/jax/submission.py +++ b/baselines/lamb/jax/submission.py @@ -169,6 +169,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/lamb/pytorch/submission.py b/baselines/lamb/pytorch/submission.py index 7d0d8763e..60483cb98 100644 --- a/baselines/lamb/pytorch/submission.py +++ b/baselines/lamb/pytorch/submission.py @@ -262,6 +262,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/momentum/jax/submission.py b/baselines/momentum/jax/submission.py index 4139ebcf6..d24b977d0 100644 --- a/baselines/momentum/jax/submission.py +++ b/baselines/momentum/jax/submission.py @@ -195,6 +195,10 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/momentum/pytorch/submission.py b/baselines/momentum/pytorch/submission.py index b7d87924d..29bca94e9 100644 --- a/baselines/momentum/pytorch/submission.py +++ b/baselines/momentum/pytorch/submission.py @@ -148,6 +148,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nadamw/jax/submission.py b/baselines/nadamw/jax/submission.py index 099613fcf..bbfbfcee4 100644 --- a/baselines/nadamw/jax/submission.py +++ b/baselines/nadamw/jax/submission.py @@ -303,6 +303,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index 01cffc52e..6ba7c2354 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -305,6 +305,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nesterov/jax/submission.py b/baselines/nesterov/jax/submission.py index 35cebba1f..202c3abe8 100644 --- a/baselines/nesterov/jax/submission.py +++ b/baselines/nesterov/jax/submission.py @@ -195,6 +195,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nesterov/pytorch/submission.py b/baselines/nesterov/pytorch/submission.py index 45feb8645..2af8072f8 100644 --- a/baselines/nesterov/pytorch/submission.py +++ b/baselines/nesterov/pytorch/submission.py @@ -148,6 +148,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/sam/jax/submission.py b/baselines/sam/jax/submission.py index 85b3d7441..749577e9d 100644 --- a/baselines/sam/jax/submission.py +++ b/baselines/sam/jax/submission.py @@ -248,6 +248,10 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/sam/pytorch/submission.py b/baselines/sam/pytorch/submission.py index 2cab75972..1828986bd 100644 --- a/baselines/sam/pytorch/submission.py +++ b/baselines/sam/pytorch/submission.py @@ -220,6 +220,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/shampoo/jax/submission.py b/baselines/shampoo/jax/submission.py index cb062faf3..d4193c592 100644 --- a/baselines/shampoo/jax/submission.py +++ b/baselines/shampoo/jax/submission.py @@ -163,6 +163,10 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 + elif workoad_name == 'criteo1tb_layernorm' + return 262_144 + elif workload_name == 'criteo1tb_resnet' + return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': From 87a04250decbc96ddfc805468e1111c2d04aafc0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 20:33:13 +0000 Subject: [PATCH 012/169] change regression test name for variants --- .github/workflows/regression_tests_variants.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 11957712e..13a21f847 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -1,4 +1,4 @@ -name: Containerized Regression Tests +name: Containerized Regression Tests for Workload Variants on: pull_request: From 6e6e7868ae6bdb3a58f7aca89e5454b93617c844 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 20:42:14 +0000 Subject: [PATCH 013/169] syntax fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 6d37565b8..6e1688206 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -115,8 +115,8 @@ def __init__(self, - def forward(self, x): - # Todo (kasimbeg): add residual layer + def forward(self, x): + # Todo (kasimbeg): add residual layer batch_size = x.shape[0] dense_features, sparse_features = torch.split( From dc208d01abce08ebb99372ab5fd75c96178ab1aa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 20:44:01 +0000 Subject: [PATCH 014/169] lint fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 6e1688206..3a055f5b6 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -116,7 +116,7 @@ def 
__init__(self, def forward(self, x): - # Todo (kasimbeg): add residual layer + # Todo (kasimbeg): add residual layer batch_size = x.shape[0] dense_features, sparse_features = torch.split( @@ -138,7 +138,7 @@ def forward(self, x): # Final MLP. logits = self.top_mlp(concatenated_dense) - return logits + return logits class DlrmSmall(nn.Module): From 2ee0f77a1bb126f60a965fa35e8c2b2a7fee49cb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 21:09:40 +0000 Subject: [PATCH 015/169] formatting --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 3a055f5b6..72cc22f29 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -113,8 +113,6 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - - def forward(self, x): # Todo (kasimbeg): add residual layer batch_size = x.shape[0] From fa756636e9a8361609e41e218dd28db3f9c791e7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 22:29:04 +0000 Subject: [PATCH 016/169] fix --- baselines/adafactor/jax/submission.py | 4 ++-- baselines/adafactor/pytorch/submission.py | 4 ++-- baselines/adamw/jax/submission.py | 4 ++-- baselines/adamw/pytorch/submission.py | 4 ++-- baselines/lamb/jax/submission.py | 4 ++-- baselines/lamb/pytorch/submission.py | 4 ++-- baselines/momentum/jax/submission.py | 4 ++-- baselines/momentum/pytorch/submission.py | 4 ++-- baselines/nadamw/jax/submission.py | 4 ++-- baselines/nadamw/pytorch/submission.py | 4 ++-- baselines/nesterov/jax/submission.py | 4 ++-- baselines/nesterov/pytorch/submission.py | 4 ++-- baselines/sam/jax/submission.py | 4 ++-- baselines/sam/pytorch/submission.py | 4 ++-- baselines/shampoo/jax/submission.py | 4 ++-- 15 files changed, 30 insertions(+), 30 deletions(-) diff --git a/baselines/adafactor/jax/submission.py b/baselines/adafactor/jax/submission.py index e9beb2aea..74d3b8df6 100644 --- a/baselines/adafactor/jax/submission.py +++ b/baselines/adafactor/jax/submission.py @@ -160,9 +160,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/adafactor/pytorch/submission.py b/baselines/adafactor/pytorch/submission.py index 9cc5af911..fdab24305 100644 --- a/baselines/adafactor/pytorch/submission.py +++ b/baselines/adafactor/pytorch/submission.py @@ -269,9 +269,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/adamw/jax/submission.py b/baselines/adamw/jax/submission.py index 54679c5ee..aaa69d78d 100644 --- a/baselines/adamw/jax/submission.py +++ b/baselines/adamw/jax/submission.py @@ -161,9 +161,9 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/adamw/pytorch/submission.py b/baselines/adamw/pytorch/submission.py index 3fe3ba74c..3be5de6fe 100644 --- a/baselines/adamw/pytorch/submission.py +++ b/baselines/adamw/pytorch/submission.py @@ -129,9 +129,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/lamb/jax/submission.py b/baselines/lamb/jax/submission.py index 24c7f978b..3ef80dc05 100644 --- a/baselines/lamb/jax/submission.py +++ b/baselines/lamb/jax/submission.py @@ -169,9 +169,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/lamb/pytorch/submission.py b/baselines/lamb/pytorch/submission.py index 60483cb98..ca3056f44 100644 --- a/baselines/lamb/pytorch/submission.py +++ b/baselines/lamb/pytorch/submission.py @@ -262,9 +262,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/momentum/jax/submission.py b/baselines/momentum/jax/submission.py index d24b977d0..2ab18e977 100644 --- a/baselines/momentum/jax/submission.py +++ b/baselines/momentum/jax/submission.py @@ -195,9 +195,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/momentum/pytorch/submission.py b/baselines/momentum/pytorch/submission.py index 29bca94e9..31c35bafc 100644 --- a/baselines/momentum/pytorch/submission.py +++ b/baselines/momentum/pytorch/submission.py @@ -148,9 +148,9 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/nadamw/jax/submission.py b/baselines/nadamw/jax/submission.py index bbfbfcee4..19f7491d5 100644 --- a/baselines/nadamw/jax/submission.py +++ b/baselines/nadamw/jax/submission.py @@ -303,9 +303,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index 6ba7c2354..0f6cc7d9c 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -305,9 +305,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/nesterov/jax/submission.py b/baselines/nesterov/jax/submission.py index 202c3abe8..5c8d5d2e7 100644 --- a/baselines/nesterov/jax/submission.py +++ b/baselines/nesterov/jax/submission.py @@ -195,9 +195,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/nesterov/pytorch/submission.py b/baselines/nesterov/pytorch/submission.py index 2af8072f8..1cc41b7ea 100644 --- a/baselines/nesterov/pytorch/submission.py +++ b/baselines/nesterov/pytorch/submission.py @@ -148,9 +148,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/sam/jax/submission.py b/baselines/sam/jax/submission.py index 749577e9d..132894c15 100644 --- a/baselines/sam/jax/submission.py +++ b/baselines/sam/jax/submission.py @@ -248,9 +248,9 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/sam/pytorch/submission.py b/baselines/sam/pytorch/submission.py index 1828986bd..3c1416ffa 100644 --- a/baselines/sam/pytorch/submission.py +++ b/baselines/sam/pytorch/submission.py @@ -220,9 +220,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 diff --git a/baselines/shampoo/jax/submission.py b/baselines/shampoo/jax/submission.py index d4193c592..74465b3eb 100644 --- a/baselines/shampoo/jax/submission.py +++ b/baselines/shampoo/jax/submission.py @@ -163,9 +163,9 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm' + elif workoad_name == 'criteo1tb_layernorm': return 262_144 - elif workload_name == 'criteo1tb_resnet' + elif workload_name == 'criteo1tb_resnet': return 262_144 elif workload_name == 'fastmri': return 32 From d21f05b70c972bb559e7777de7ea9dc1e25007d1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 14 Nov 2023 23:46:33 +0000 Subject: [PATCH 017/169] fix typo --- baselines/adafactor/jax/submission.py | 2 +- baselines/adafactor/pytorch/submission.py | 2 +- baselines/adamw/jax/submission.py | 2 +- baselines/adamw/pytorch/submission.py | 2 +- baselines/lamb/jax/submission.py | 2 +- baselines/lamb/pytorch/submission.py | 2 +- baselines/momentum/jax/submission.py | 2 +- baselines/momentum/pytorch/submission.py | 2 +- baselines/nadamw/jax/submission.py | 2 +- baselines/nadamw/pytorch/submission.py | 2 +- baselines/nesterov/jax/submission.py | 2 +- baselines/nesterov/pytorch/submission.py | 2 +- baselines/sam/jax/submission.py | 2 +- baselines/sam/pytorch/submission.py | 2 +- baselines/shampoo/jax/submission.py | 2 +- 15 files changed, 15 insertions(+), 15 deletions(-) diff --git a/baselines/adafactor/jax/submission.py b/baselines/adafactor/jax/submission.py index 74d3b8df6..81cac0c35 100644 --- a/baselines/adafactor/jax/submission.py +++ b/baselines/adafactor/jax/submission.py @@ -160,7 +160,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/adafactor/pytorch/submission.py b/baselines/adafactor/pytorch/submission.py index fdab24305..11d6cb36f 100644 --- a/baselines/adafactor/pytorch/submission.py +++ b/baselines/adafactor/pytorch/submission.py @@ -269,7 +269,7 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/adamw/jax/submission.py b/baselines/adamw/jax/submission.py index aaa69d78d..34e769a9a 100644 --- a/baselines/adamw/jax/submission.py +++ b/baselines/adamw/jax/submission.py @@ -161,7 +161,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/adamw/pytorch/submission.py b/baselines/adamw/pytorch/submission.py index 3be5de6fe..d3812899a 100644 --- a/baselines/adamw/pytorch/submission.py +++ b/baselines/adamw/pytorch/submission.py @@ -129,7 +129,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/lamb/jax/submission.py b/baselines/lamb/jax/submission.py index 3ef80dc05..d293c6d7f 100644 --- a/baselines/lamb/jax/submission.py +++ b/baselines/lamb/jax/submission.py @@ -169,7 +169,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/lamb/pytorch/submission.py b/baselines/lamb/pytorch/submission.py index ca3056f44..144d61696 100644 --- a/baselines/lamb/pytorch/submission.py +++ b/baselines/lamb/pytorch/submission.py @@ -262,7 +262,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/momentum/jax/submission.py b/baselines/momentum/jax/submission.py index 2ab18e977..191765f7b 100644 --- a/baselines/momentum/jax/submission.py +++ b/baselines/momentum/jax/submission.py @@ -195,7 +195,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/momentum/pytorch/submission.py b/baselines/momentum/pytorch/submission.py index 31c35bafc..e12e7b1fd 100644 --- a/baselines/momentum/pytorch/submission.py +++ b/baselines/momentum/pytorch/submission.py @@ -148,7 +148,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/nadamw/jax/submission.py b/baselines/nadamw/jax/submission.py index 19f7491d5..5dd7794b3 100644 --- a/baselines/nadamw/jax/submission.py +++ b/baselines/nadamw/jax/submission.py @@ -303,7 +303,7 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index 0f6cc7d9c..458395067 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -305,7 +305,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/nesterov/jax/submission.py b/baselines/nesterov/jax/submission.py index 5c8d5d2e7..eeb3940e1 100644 --- a/baselines/nesterov/jax/submission.py +++ b/baselines/nesterov/jax/submission.py @@ -195,7 +195,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/nesterov/pytorch/submission.py b/baselines/nesterov/pytorch/submission.py index 1cc41b7ea..7031e9888 100644 --- a/baselines/nesterov/pytorch/submission.py +++ b/baselines/nesterov/pytorch/submission.py @@ -148,7 +148,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/sam/jax/submission.py b/baselines/sam/jax/submission.py index 132894c15..bcf9362b6 100644 --- a/baselines/sam/jax/submission.py +++ b/baselines/sam/jax/submission.py @@ -248,7 +248,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/sam/pytorch/submission.py b/baselines/sam/pytorch/submission.py index 3c1416ffa..6c2d553e8 100644 --- a/baselines/sam/pytorch/submission.py +++ b/baselines/sam/pytorch/submission.py @@ -220,7 +220,7 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 diff --git a/baselines/shampoo/jax/submission.py b/baselines/shampoo/jax/submission.py index 74465b3eb..c873c98e1 100644 --- a/baselines/shampoo/jax/submission.py +++ b/baselines/shampoo/jax/submission.py @@ -163,7 +163,7 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workoad_name == 'criteo1tb_layernorm': + elif workload_name == 'criteo1tb_layernorm': return 262_144 elif workload_name == 'criteo1tb_resnet': return 262_144 From 8776748a900d88371437fe4788b66dd29ff4adaf Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 15 Nov 2023 01:16:01 +0000 Subject: [PATCH 018/169] fix resnet --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index bb7d02879..316401649 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -78,7 +78,7 @@ def scaled_init(key, shape, dtype=jnp.float_): stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( top_mlp_input) top_mlp_input = nn.relu(top_mlp_input) - for layer_idx, fan_out in Sequence(enumerate(mlp_top_dims))[1:-1]: + for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: fan_in = mlp_top_dims[layer_idx - 1] x = nn.Dense( fan_out, From 6c8fbe963440c45bef847337f74f94aab26b92d6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 15 Nov 2023 01:18:33 +0000 Subject: [PATCH 019/169] modify regresion test --- .../workflows/regression_tests_variants.yml | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 13a21f847..41d44cc26 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -6,39 +6,39 @@ on: - 'dev' jobs: - build_and_push_jax_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=jax - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? - if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - build_and_push_pytorch_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=pytorch - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? 
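A note on the `fix resnet` change above (algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py): `Sequence` in that file comes from `typing`, so it is only an annotation and cannot be called, and `Sequence(enumerate(mlp_top_dims))` raises a `TypeError` when the model's `__call__` runs; `enumerate(...)` on its own cannot be sliced either. Wrapping it in `list(...)` fixes both problems. A minimal standalone sketch of the corrected residual-loop indexing follows; the tuple value is illustrative, not taken from any particular config.

    # Sketch only: why the loop needs list(), and what the [1:-1] slice selects.
    mlp_top_dims = (256, 256, 256, 256, 1)  # illustrative top-MLP widths

    # typing.Sequence is an abstract annotation, not a constructor, so
    # Sequence(enumerate(mlp_top_dims)) raises a TypeError; enumerate() is a
    # lazy iterator and does not support slicing. list() resolves both.
    hidden_layers = list(enumerate(mlp_top_dims))[1:-1]  # [(1, 256), (2, 256), (3, 256)]
    for layer_idx, fan_out in hidden_layers:
        fan_in = mlp_top_dims[layer_idx - 1]
        # Only the middle entries are visited here; the first and last entries
        # are handled outside this loop.
        print(layer_idx, fan_in, fan_out)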
- if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + # build_and_push_jax_docker_image: + # runs-on: self-hosted + # steps: + # - uses: actions/checkout@v2 + # - name: Build and push docker images + # run: | + # GIT_BRANCH=${{ github.head_ref || github.ref_name }} + # FRAMEWORK=jax + # IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + # cd $HOME/algorithmic-efficiency/docker + # docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + # BUILD_RETURN=$? + # if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + # docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + # docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + # build_and_push_pytorch_docker_image: + # runs-on: self-hosted + # steps: + # - uses: actions/checkout@v2 + # - name: Build and push docker images + # run: | + # GIT_BRANCH=${{ github.head_ref || github.ref_name }} + # FRAMEWORK=pytorch + # IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + # cd $HOME/algorithmic-efficiency/docker + # docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + # BUILD_RETURN=$? + # if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + # docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + # docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME criteo_layernorm_jax: runs-on: self-hosted - needs: build_and_push_jax_docker_image + # needs: build_and_push_jax_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload @@ -47,7 +47,7 @@ jobs: docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_jax: runs-on: self-hosted - needs: build_and_push_jax_docker_image + # needs: build_and_push_jax_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload @@ -56,7 +56,7 @@ jobs: docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_layernorm_pytorch: runs-on: self-hosted - needs: build_and_push_pytorch_docker_image + # needs: build_and_push_pytorch_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload @@ -65,7 +65,7 @@ jobs: docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host 
us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted - needs: build_and_push_pytorch_docker_image + # needs: build_and_push_pytorch_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload From 7b7e7502442f3cf1d332f5a6cf1a4b1a7ddb7b17 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 15 Nov 2023 01:26:12 +0000 Subject: [PATCH 020/169] modify_test --- .github/workflows/regression_tests_variants.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 41d44cc26..6b7f14b71 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -44,7 +44,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_jax: runs-on: self-hosted # needs: build_and_push_jax_docker_image @@ -53,7 +53,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_layernorm_pytorch: runs-on: 
self-hosted # needs: build_and_push_pytorch_docker_image @@ -62,7 +62,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted # needs: build_and_push_pytorch_docker_image @@ -71,5 +71,5 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false \ No newline at end of file From e601db458f256b379929b5439b63b2d60d11dedb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 01:28:25 +0000 Subject: [PATCH 021/169] add tests --- .../criteo1tb_layernorm/__init__.py | 0 .../modeldiffs/criteo1tb_layernorm/compare.py | 84 +++++++++++++++++++ tests/modeldiffs/criteo1tb_resnet/__init__.py | 0 tests/modeldiffs/criteo1tb_resnet/compare.py | 84 +++++++++++++++++++ 4 files changed, 168 insertions(+) create mode 100644 tests/modeldiffs/criteo1tb_layernorm/__init__.py create mode 100644 tests/modeldiffs/criteo1tb_layernorm/compare.py create mode 100644 tests/modeldiffs/criteo1tb_resnet/__init__.py create mode 100644 tests/modeldiffs/criteo1tb_resnet/compare.py diff --git a/tests/modeldiffs/criteo1tb_layernorm/__init__.py b/tests/modeldiffs/criteo1tb_layernorm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py 
b/tests/modeldiffs/criteo1tb_layernorm/compare.py new file mode 100644 index 000000000..d46d57d15 --- /dev/null +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -0,0 +1,84 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ + Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ + Criteo1TbDlrmSmallLayerNormWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + s_count = None + for i in k: + if 'Sequential' in i: + s_count = int(i.split('_')[1]) + continue + if 'Embedding' in i: + return ('embedding_table',) + if 'Linear' in i: + i = i.replace('Linear', 'Dense') + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) + elif 'weight' in i: + i = i.replace('weight', 'kernel') + + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = { + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), + } + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/criteo1tb_resnet/__init__.py b/tests/modeldiffs/criteo1tb_resnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py new file mode 100644 index 000000000..761cc47bc --- /dev/null +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -0,0 +1,84 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ + Criteo1TbDlrmSmallResNetWorkload as JaxWorkload +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ + Criteo1TbDlrmSmallResNetWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + s_count = None + for i in k: + if 'Sequential' in i: + s_count = int(i.split('_')[1]) + continue + if 'Embedding' in i: + return ('embedding_table',) + if 'Linear' in i: + i = i.replace('Linear', 'Dense') + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) + elif 'weight' in i: + i = i.replace('weight', 'kernel') + + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = { + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), + } + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) From 8d2e828d1c480920126bb03a94fe2c301fb3ff72 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 06:24:46 +0000 Subject: [PATCH 022/169] regression tests --- .../workflows/regression_tests_variants.yml | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 6b7f14b71..22f49b523 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -3,42 +3,42 @@ name: Containerized Regression Tests for Workload Variants on: pull_request: branches: - - 'dev' + - 'main' jobs: - # build_and_push_jax_docker_image: - # runs-on: self-hosted - # steps: - # - uses: actions/checkout@v2 - # - name: Build and push docker images - # run: | - # GIT_BRANCH=${{ github.head_ref || github.ref_name }} - # FRAMEWORK=jax - # IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - # cd $HOME/algorithmic-efficiency/docker - # docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - # BUILD_RETURN=$? 
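For context on the two compare.py scripts added in the `add tests` patch above: `key_transform` rewrites the parameter paths obtained from walking the PyTorch module tree into the names Flax assigns to the equivalent parameters, and `sd_transform` concatenates the chunked PyTorch embedding back into the single `embedding_table` used on the Jax side, so that `out_diff` can match the two state dicts key-for-key. A rough, self-contained illustration of that renaming is below; the key tuples are made-up examples in the style those scripts expect, not keys dumped from the real models, and the simplified function is not the scripts' exact code.

    # Illustrative only: the kind of renaming key_transform performs.
    def rename(pytorch_key):
        # Any embedding parameter collapses to the unified Jax table name.
        if any('Embedding' in part for part in pytorch_key):
            return ('embedding_table',)
        out, seq_idx = [], 0
        for part in pytorch_key:
            if part.startswith('Sequential'):
                seq_idx = int(part.split('_')[1])  # remember which Sequential block we are in
                continue
            if part.startswith('Linear'):
                # Same seq_idx * 3 offset the compare scripts use to line up
                # per-block Linear indices with Flax's flat Dense numbering.
                part = f"Dense_{seq_idx * 3 + int(part.split('_')[1])}"
            elif part == 'weight':
                part = 'kernel'
            out.append(part)
        return tuple(out)

    print(rename(('Sequential_0', 'Linear_1', 'weight')))  # ('Dense_1', 'kernel')
    print(rename(('Embedding_0', 'weight')))               # ('embedding_table',)

If these follow the pattern of the other modeldiffs tests, each script is presumably run directly, e.g. `python3 tests/modeldiffs/criteo1tb_resnet/compare.py`, and runs on CPU because of the `CUDA_VISIBLE_DEVICES` override at the top.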
- # if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - # docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - # docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - # build_and_push_pytorch_docker_image: - # runs-on: self-hosted - # steps: - # - uses: actions/checkout@v2 - # - name: Build and push docker images - # run: | - # GIT_BRANCH=${{ github.head_ref || github.ref_name }} - # FRAMEWORK=pytorch - # IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - # cd $HOME/algorithmic-efficiency/docker - # docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - # BUILD_RETURN=$? - # if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - # docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - # docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + build_and_push_jax_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=jax + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + build_and_push_pytorch_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker images + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=pytorch + IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? 
+ if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME criteo_layernorm_jax: runs-on: self-hosted - # needs: build_and_push_jax_docker_image + needs: build_and_push_jax_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload @@ -47,7 +47,7 @@ jobs: docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_jax: runs-on: self-hosted - # needs: build_and_push_jax_docker_image + needs: build_and_push_jax_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload @@ -56,7 +56,7 @@ jobs: docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_layernorm_pytorch: runs-on: self-hosted - # needs: build_and_push_pytorch_docker_image + needs: build_and_push_pytorch_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload @@ -65,7 +65,7 @@ jobs: docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted - # needs: build_and_push_pytorch_docker_image + needs: build_and_push_pytorch_docker_image steps: - uses: actions/checkout@v2 - name: Run containerized workload From f5170c5ceffe02c306301bb7c3a65e26a7dfa3cf Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 07:26:17 +0000 Subject: [PATCH 023/169] extract fix --- datasets/dataset_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2f808b64b..f9ee2f138 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -388,7 +388,7 @@ def download_fastmri(data_dir, def extract(source, dest, mode='r:xz'): if not os.path.exists(dest): - os.path.makedirs(dest) + os.makedirs(dest) logging.info(f'Extracting {source} to {dest}') tar = tarfile.open(source, mode) logging.info('Opened tar') From 21938308161259083003036d6d48b96ed8bf57f6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 20:10:25 +0000 Subject: [PATCH 024/169] remove 
variant bsz --- .github/workflows/regression_tests_variants.yml | 9 +++++---- baselines/adafactor/jax/submission.py | 4 ---- baselines/adafactor/pytorch/submission.py | 4 ---- baselines/adamw/jax/submission.py | 4 ---- baselines/adamw/pytorch/submission.py | 4 ---- baselines/lamb/jax/submission.py | 4 ---- baselines/lamb/pytorch/submission.py | 4 ---- baselines/momentum/jax/submission.py | 4 ---- baselines/momentum/pytorch/submission.py | 4 ---- baselines/nadamw/jax/submission.py | 4 ---- baselines/nadamw/pytorch/submission.py | 4 ---- baselines/nesterov/jax/submission.py | 4 ---- baselines/nesterov/pytorch/submission.py | 4 ---- baselines/sam/jax/submission.py | 4 ---- baselines/sam/pytorch/submission.py | 4 ---- baselines/shampoo/jax/submission.py | 4 ---- 16 files changed, 5 insertions(+), 64 deletions(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 22f49b523..d581c80fc 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -44,7 +44,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -53,7 +53,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_layernorm_pytorch: runs-on: self-hosted 
needs: build_and_push_pytorch_docker_image @@ -62,7 +62,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -71,5 +71,6 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + \ No newline at end of file diff --git a/baselines/adafactor/jax/submission.py b/baselines/adafactor/jax/submission.py index 81cac0c35..ec8020e7e 100644 --- a/baselines/adafactor/jax/submission.py +++ b/baselines/adafactor/jax/submission.py @@ -160,10 +160,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/adafactor/pytorch/submission.py b/baselines/adafactor/pytorch/submission.py index 11d6cb36f..e6fef17dc 100644 --- a/baselines/adafactor/pytorch/submission.py +++ b/baselines/adafactor/pytorch/submission.py @@ -269,10 +269,6 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/adamw/jax/submission.py b/baselines/adamw/jax/submission.py index 34e769a9a..11212c1a0 100644 --- a/baselines/adamw/jax/submission.py +++ b/baselines/adamw/jax/submission.py @@ -161,10 +161,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/adamw/pytorch/submission.py b/baselines/adamw/pytorch/submission.py index d3812899a..75a4abbef 100644 --- a/baselines/adamw/pytorch/submission.py +++ b/baselines/adamw/pytorch/submission.py @@ -129,10 +129,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/lamb/jax/submission.py b/baselines/lamb/jax/submission.py index d293c6d7f..27d635ee9 100644 --- a/baselines/lamb/jax/submission.py +++ b/baselines/lamb/jax/submission.py @@ -169,10 +169,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/lamb/pytorch/submission.py b/baselines/lamb/pytorch/submission.py index 144d61696..7d0d8763e 100644 --- a/baselines/lamb/pytorch/submission.py +++ b/baselines/lamb/pytorch/submission.py @@ -262,10 +262,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/momentum/jax/submission.py b/baselines/momentum/jax/submission.py index 191765f7b..4139ebcf6 100644 --- a/baselines/momentum/jax/submission.py +++ b/baselines/momentum/jax/submission.py @@ -195,10 +195,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/momentum/pytorch/submission.py b/baselines/momentum/pytorch/submission.py index e12e7b1fd..b7d87924d 100644 --- a/baselines/momentum/pytorch/submission.py +++ b/baselines/momentum/pytorch/submission.py @@ -148,10 +148,6 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nadamw/jax/submission.py b/baselines/nadamw/jax/submission.py index 5dd7794b3..099613fcf 100644 --- a/baselines/nadamw/jax/submission.py +++ b/baselines/nadamw/jax/submission.py @@ -303,10 +303,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nadamw/pytorch/submission.py b/baselines/nadamw/pytorch/submission.py index 458395067..01cffc52e 100644 --- a/baselines/nadamw/pytorch/submission.py +++ b/baselines/nadamw/pytorch/submission.py @@ -305,10 +305,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nesterov/jax/submission.py b/baselines/nesterov/jax/submission.py index eeb3940e1..35cebba1f 100644 --- a/baselines/nesterov/jax/submission.py +++ b/baselines/nesterov/jax/submission.py @@ -195,10 +195,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/nesterov/pytorch/submission.py b/baselines/nesterov/pytorch/submission.py index 7031e9888..45feb8645 100644 --- a/baselines/nesterov/pytorch/submission.py +++ b/baselines/nesterov/pytorch/submission.py @@ -148,10 +148,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/sam/jax/submission.py b/baselines/sam/jax/submission.py index bcf9362b6..85b3d7441 100644 --- a/baselines/sam/jax/submission.py +++ b/baselines/sam/jax/submission.py @@ -248,10 +248,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/sam/pytorch/submission.py b/baselines/sam/pytorch/submission.py index 6c2d553e8..2cab75972 100644 --- a/baselines/sam/pytorch/submission.py +++ b/baselines/sam/pytorch/submission.py @@ -220,10 +220,6 @@ def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': diff --git a/baselines/shampoo/jax/submission.py b/baselines/shampoo/jax/submission.py index c873c98e1..cb062faf3 100644 --- a/baselines/shampoo/jax/submission.py +++ b/baselines/shampoo/jax/submission.py @@ -163,10 +163,6 @@ def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': return 262_144 - elif workload_name == 'criteo1tb_layernorm': - return 262_144 - elif workload_name == 'criteo1tb_resnet': - return 262_144 elif workload_name == 'fastmri': return 32 elif workload_name == 'imagenet_resnet': From d46d5e02879d96fc1d40a4f02d887aab074cc287 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 20:20:18 +0000 Subject: [PATCH 025/169] add helper fn for get_baseworkload_name --- algorithmic_efficiency/workloads/workloads.py | 8 ++++++++ submission_runner.py | 6 +++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index d86db0643..c809e8234 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -57,6 +57,14 @@ 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, } +BASE_WORKLOADS = ['criteo1tb', 'fastmri', ' imagenet_resnet', 'imagenet_vit', + 'librispeech_conformer', 'librispeech_deepspeech', + 'ogbg', 'wmt'] + +def get_base_workload_name(workload_name): + for base_workload_name in BASE_WORKLOADS: + if base_workload_name in workload_name: + return base_workload_name def convert_filepath_to_module(path: str): base, extension = os.path.splitext(path) diff --git a/submission_runner.py b/submission_runner.py index 12494cd6e..37d89d136 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -487,7 +487,11 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection - global_batch_size = submission_module.get_batch_size(workload_name) + try: + global_batch_size = submission_module.get_batch_size(workload_name) + except: + base_workload_name = workloads.get_base_workload_name(workload_name) + global_batch_size = submission_module.get_batch_size(base_workload_name) # n_gpus has to be set here, because we cannot call the first Jax operation # before pytorch_init(). n_gpus = max(N_GPUS, jax.local_device_count()) From 7ddffd3615d31291a13169054b9e142aec575e54 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:24:17 +0000 Subject: [PATCH 026/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 72cc22f29..2f3b88a6c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -46,6 +46,7 @@ def __init__(self, dropout_rate=0.0, use_layer_norm=False): # use_layer_norm is unused. 
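The `remove variant bsz` and `add helper fn for get_baseworkload_name` patches above work as a pair: the per-variant branches are dropped from every baseline's `get_batch_size`, and `submission_runner.py` now falls back to the base workload's batch size when a submission does not recognize a variant name. The helper relies on a plain substring match, so `'criteo1tb_layernorm'` and `'criteo1tb_resnet'` both resolve to `'criteo1tb'`. A small sketch of that fallback path follows; `get_batch_size` here is a stand-in for a submission's function, not real code from the repo, and the helper is modeled on (not copied verbatim from) the workloads.py diff.

    # Sketch of the fallback: variant workload names reuse the base workload's batch size.
    BASE_WORKLOADS = ['criteo1tb', 'fastmri', 'imagenet_resnet', 'imagenet_vit',
                      'librispeech_conformer', 'librispeech_deepspeech', 'ogbg', 'wmt']

    def get_base_workload_name(workload_name):
        for base_workload_name in BASE_WORKLOADS:
            if base_workload_name in workload_name:  # substring match, e.g. 'criteo1tb' in 'criteo1tb_resnet'
                return base_workload_name
        # Implicitly returns None when nothing matches, as in the diff above.

    def get_batch_size(workload_name):  # hypothetical submission-side function
        return {'criteo1tb': 262_144}[workload_name]

    def resolve_batch_size(workload_name):
        try:
            return get_batch_size(workload_name)
        except KeyError:  # the runner itself uses a bare `except`
            return get_batch_size(get_base_workload_name(workload_name))

    print(resolve_batch_size('criteo1tb_layernorm'))  # 262144, via the 'criteo1tb' fallback

One consequence worth noting: if the base lookup also fails (for example, a name that matches none of the entries), the error surfaces from inside the except branch rather than from the original call.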
del use_layer_norm + super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features self.num_sparse_features = num_sparse_features From 6b2ed1f802fbac6b58e986499903a8ead4d51ff6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:32:39 +0000 Subject: [PATCH 027/169] modify conformer_resnet model --- .../criteo1tb/criteo1tb_pytorch/models.py | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 2f3b88a6c..9b2da7144 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -70,49 +70,54 @@ def __init__(self, self.embedding_table_chucks.append(chunk) # bottom mlp - self.bot_layers = [] + bottom_mlp_layers = [] input_dim = self.num_dense_features for dense_dim in self.mlp_bottom_dims: - block = [] - block.append(nn.Linear(input_dim, dense_dim)) - block.append(nn.ReLU(inplace=True)) - self.bot_layers.append(nn.Sequential(*block)) + bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) + bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim - for layer in self.bot_layers: - for module in layer.modules(): - if isinstance(module, nn.Linear): - limit = math.sqrt(6. / (module.in_features + module.out_features)) - nn.init.uniform_(module.weight.data, -limit, limit) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. / module.out_features)) + self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - # top mlp - # TODO (JB): Write down the formula here instead of the constant. + + # TODO: Write down the formula here instead of the constant. input_dims = 506 - self.top_layers = [] + top_mlp_layers = [] num_layers_top = len(self.mlp_top_dims) for layer_idx, fan_out in enumerate(self.mlp_top_dims): - block = [] - fan_in = (input_dims + self.embed_dim) if layer_idx == 0 \ + fan_in = input_dims if layer_idx == 0 \ else self.mlp_top_dims[layer_idx - 1] - block.append(nn.Linear(fan_in, fan_out)) + top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): - block.append(nn.ReLU(inplace=True)) + top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - block.append(nn.Dropout(p=dropout_rate)) - self.top_layers.append(nn.Sequential(*block)) - for layer in self.top_layers: - for module in layer.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_( - module.weight.data, - 0., - math.sqrt(2. / (module.in_features + module.out_features))) - nn.init.normal_(module.bias.data, - 0., - math.sqrt(1. 
/ module.out_features)) + top_mlp_layers.append(nn.Dropout(p=dropout_rate)) + self.top_mlp = nn.Sequential(*top_mlp_layers) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) def forward(self, x): # Todo (kasimbeg): add residual layer From cd2d6727cdf5dc6a3f3415dd8e466e83bdd6d6bb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:38:53 +0000 Subject: [PATCH 028/169] ln --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 9b2da7144..ceaa7a730 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -75,8 +75,6 @@ def __init__(self, for dense_dim in self.mlp_bottom_dims: bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) bottom_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): @@ -99,16 +97,10 @@ def __init__(self, top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): top_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) self.top_mlp = nn.Sequential(*top_mlp_layers) - if use_layer_norm: - self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) - else: - self.embed_ln = None for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( From f0a369ad0783baba002772bc69d0521cf34ff727 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:42:20 +0000 Subject: [PATCH 029/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index ceaa7a730..315ffb889 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -112,7 +112,6 @@ def __init__(self, math.sqrt(1. 
/ module.out_features)) def forward(self, x): - # Todo (kasimbeg): add residual layer batch_size = x.shape[0] dense_features, sparse_features = torch.split( From 01d668c4427df2f53dd11c4f48cf4757f4302280 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:45:38 +0000 Subject: [PATCH 030/169] fix --- .../criteo1tb/criteo1tb_pytorch/models.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 315ffb889..1c4c09e3c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -35,17 +35,15 @@ class DLRMResNet(nn.Module): mlp_top_dims: dimensions of dense layers of the top mlp. embed_dim: embedding dimension. """ - def __init__(self, vocab_size, num_dense_features=13, num_sparse_features=26, - mlp_bottom_dims=(256, 256, 256), - mlp_top_dims=(256, 256, 256, 256, 1), + mlp_bottom_dims=(512, 256, 128), + mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, dropout_rate=0.0, - use_layer_norm=False): # use_layer_norm is unused. - del use_layer_norm + use_layer_norm=False): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -69,12 +67,13 @@ def __init__(self, self.register_parameter(f'embedding_chunk_{i}', chunk) self.embedding_table_chucks.append(chunk) - # bottom mlp bottom_mlp_layers = [] input_dim = self.num_dense_features for dense_dim in self.mlp_bottom_dims: bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): @@ -97,10 +96,16 @@ def __init__(self, top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) self.top_mlp = nn.Sequential(*top_mlp_layers) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( @@ -127,6 +132,8 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, -1, self.embed_dim]) + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) # Dot product interactions. 
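For context on the embed_ln call added above: nn.LayerNorm(self.embed_dim) is applied to a [batch_size, num_sparse_features, embed_dim] tensor, so each sparse feature's embedding vector is normalized on its own over the embed_dim axis. A standalone sketch of that behaviour, assuming embed_dim=128 and 26 sparse features; this is illustrative code, not part of the patch:

import torch
from torch import nn

embed_ln = nn.LayerNorm(128, eps=1e-6)
fake_embeddings = torch.randn(32, 26, 128)  # [batch, sparse feature, embed_dim]
normed = embed_ln(fake_embeddings)
# Normalization is per (example, feature) slice: means ~0 and stds ~1 along the last axis.
print(normed.mean(dim=-1).abs().max(), normed.std(dim=-1).mean())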
concatenated_dense = self.dot_interact( dense_features=embedded_dense, sparse_features=embedded_sparse) From 910e9741b53192cfe52ff62518394276efb16a54 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:47:16 +0000 Subject: [PATCH 031/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 1c4c09e3c..300662d5c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -25,8 +25,8 @@ def forward(self, dense_features, sparse_features): return torch.cat((dense_features, interactions_flat), dim=1) -class DLRMResNet(nn.Module): - """Define a DLRM-ResNet model. +class DlrmResnet(nn.Module): + """Define a DLRM-Small model. Parameters: vocab_size: vocab size of embedding table. @@ -35,6 +35,7 @@ class DLRMResNet(nn.Module): mlp_top_dims: dimensions of dense layers of the top mlp. embed_dim: embedding dimension. """ + def __init__(self, vocab_size, num_dense_features=13, From 063c2293c8db30daab0d4487a5e6380076ac0f79 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:50:59 +0000 Subject: [PATCH 032/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 300662d5c..6872f32c5 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -25,7 +25,7 @@ def forward(self, dense_features, sparse_features): return torch.cat((dense_features, interactions_flat), dim=1) -class DlrmResnet(nn.Module): +class DlrmResNet(nn.Module): """Define a DLRM-Small model. Parameters: From 2bf1988f3f6c255d703ed741172f5537ec6eeb3b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:52:54 +0000 Subject: [PATCH 033/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 6872f32c5..3e46a7911 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -25,7 +25,7 @@ def forward(self, dense_features, sparse_features): return torch.cat((dense_features, interactions_flat), dim=1) -class DlrmResNet(nn.Module): +class DLRMResNet(nn.Module): """Define a DLRM-Small model. 
Parameters: From 2c2a7a9b4c6b74eb026cf8436f479f7327ea0e40 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 16 Nov 2023 21:58:46 +0000 Subject: [PATCH 034/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 693573064..a66822815 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,8 +178,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims = (256, 256, 256) - mlp_top_dims = (256, 256, 256, 256, 1) + # mlp_bottom_dims = (256, 256, 256) + # mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index c049ac1de..d903236aa 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -261,8 +261,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims = (256, 256, 256) - mlp_top_dims = (256, 256, 256, 256, 1) + # mlp_bottom_dims = (256, 256, 256) + # mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: From e698bc79e96928dbeec0c514eb56d6fe4223ef33 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:01:51 +0000 Subject: [PATCH 035/169] resnet block --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 3e46a7911..cd134c77c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,7 +5,14 @@ import torch from torch import nn - +class ResNetBlock(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, x): + return self.module(x) + x + + class DotInteract(nn.Module): """Performs feature interaction operation between dense or sparse features.""" From 9fd4efa964cda23c34f81025b2d2d8ad5e698121 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:08:00 +0000 Subject: [PATCH 036/169] add resnet block to criteo resnet variant --- .../criteo1tb/criteo1tb_pytorch/models.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index cd134c77c..66304e937 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -5,14 +5,17 @@ import torch from torch import nn + class ResNetBlock(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - def forward(self, x): - return self.module(x) + x - - + """Resnet block""""" + def 
__init__(self, module): + super().__init__() + self.module = module + + def forward(self, x): + return self.module(x) + x + + class DotInteract(nn.Module): """Performs feature interaction operation between dense or sparse features.""" @@ -202,7 +205,7 @@ def __init__(self, if use_layer_norm: bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + self.bot_mlp = nn.Sequential([ResNetBlock(layer) for layer in bottom_mlp_layers]) for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): limit = math.sqrt(6. / (module.in_features + module.out_features)) @@ -228,7 +231,7 @@ def __init__(self, if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) + self.top_mlp = nn.Sequential([ResNetBlock(layer) for layer in top_mlp_layers]) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: From bc78a19ed13259b35a4714725d4cdcb267f3fd99 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:15:06 +0000 Subject: [PATCH 037/169] add back dims --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index a66822815..693573064 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,8 +178,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - # mlp_bottom_dims = (256, 256, 256) - # mlp_top_dims = (256, 256, 256, 256, 1) + mlp_bottom_dims = (256, 256, 256) + mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: From b1d22247cd1b34405046bc40ee4f9869904faeb7 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:19:52 +0000 Subject: [PATCH 038/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 316401649..f27d0c73c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -20,8 +20,8 @@ class DLRMResNet(nn.Module): vocab_size: int = 32 * 128 * 1024 # 4_194_304 num_dense_features: int = 13 - mlp_bottom_dims: Sequence[int] = (256, 256, 256) - mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) + mlp_bottom_dims: Sequence[int] = (512, 256, 128) + mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. 
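A note on the MLP dimension churn in the commits around this one: the residual blocks in DLRMResNet add each block's input back onto its output, so consecutive hidden widths must match for the sum to have a consistent shape. That is presumably why this variant keeps converging on constant-width stacks such as (256, 256, 256) instead of DlrmSmall's tapered (512, 256, 128) / (1024, 1024, 512, 256, 1). A minimal PyTorch sketch of the constraint; ResidualDense is an illustrative name and deliberately simpler than the ResNetBlock wrapper used in these patches:

import torch
from torch import nn


class ResidualDense(nn.Module):
  """One constant-width residual block: in_features must equal out_features so x + f(x) is valid."""

  def __init__(self, dim):
    super().__init__()
    self.layer = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(inplace=True))

  def forward(self, x):
    return x + self.layer(x)


# Three blocks of width 256, mirroring mlp_bottom_dims = (256, 256, 256).
stack = nn.Sequential(*[ResidualDense(256) for _ in range(3)])
print(stack(torch.randn(8, 256)).shape)  # torch.Size([8, 256])

Related detail: nn.Sequential takes modules as positional arguments (or an OrderedDict), not a plain Python list, hence the * unpacking above; the later debugging commits in this series make the same change to the ResNetBlock-wrapped layer lists.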
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 693573064..a66822815 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,8 +178,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims = (256, 256, 256) - mlp_top_dims = (256, 256, 256, 256, 1) + # mlp_bottom_dims = (256, 256, 256) + # mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 66304e937..13dd4f788 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -45,7 +45,6 @@ class DLRMResNet(nn.Module): mlp_top_dims: dimensions of dense layers of the top mlp. embed_dim: embedding dimension. """ - def __init__(self, vocab_size, num_dense_features=13, From 6a46a6c089d44dd8a765215bb6c3aa0aa024dd1c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:23:05 +0000 Subject: [PATCH 039/169] resnet --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index a66822815..693573064 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,8 +178,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - # mlp_bottom_dims = (256, 256, 256) - # mlp_top_dims = (256, 256, 256, 256, 1) + mlp_bottom_dims = (256, 256, 256) + mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index d903236aa..c049ac1de 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -261,8 +261,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - # mlp_bottom_dims = (256, 256, 256) - # mlp_top_dims = (256, 256, 256, 256, 1) + mlp_bottom_dims = (256, 256, 256) + mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: From 83c4adefef9bde3ba4ff888ea5040ce94e7f13ac Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:24:51 +0000 Subject: [PATCH 040/169] comment out pytorch --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index c049ac1de..d903236aa 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -261,8 +261,8 @@ def 
test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims = (256, 256, 256) - mlp_top_dims = (256, 256, 256, 256, 1) + # mlp_bottom_dims = (256, 256, 256) + # mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: From e60a15962d464dbb370c45eaff04fc0683add2ac Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:46:23 +0000 Subject: [PATCH 041/169] resnet fix' --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 -- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 2 -- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index f27d0c73c..41117ea45 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -20,8 +20,8 @@ class DLRMResNet(nn.Module): vocab_size: int = 32 * 128 * 1024 # 4_194_304 num_dense_features: int = 13 - mlp_bottom_dims: Sequence[int] = (512, 256, 128) - mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) + mlp_bottom_dims: Sequence[int] = (256, 256, 256) + mlp_top_dims: Sequence[int] = (256, 1256, 256, 1) embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 693573064..3fc7f849b 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,8 +178,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims = (256, 256, 256) - mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 13dd4f788..b6a838c81 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -49,8 +49,8 @@ def __init__(self, vocab_size, num_dense_features=13, num_sparse_features=26, - mlp_bottom_dims=(512, 256, 128), - mlp_top_dims=(1024, 1024, 512, 256, 1), + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, dropout_rate=0.0, use_layer_norm=False): diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index d903236aa..bcd5c1fd8 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -261,8 +261,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - # mlp_bottom_dims = (256, 256, 256) - # mlp_top_dims = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: From e95fe17b373e0021440edf7301a518088cb8ce8d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:48:29 +0000 Subject: [PATCH 042/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 41117ea45..34376701d 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -21,7 +21,7 @@ class DLRMResNet(nn.Module): vocab_size: int = 32 * 128 * 1024 # 4_194_304 num_dense_features: int = 13 mlp_bottom_dims: Sequence[int] = (256, 256, 256) - mlp_top_dims: Sequence[int] = (256, 1256, 256, 1) + mlp_top_dims: Sequence[int] = (256, 256, 256, 1) embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. From 1e5c70908d71895f4232df81392636605db74953 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:48:52 +0000 Subject: [PATCH 043/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 34376701d..316401649 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -21,7 +21,7 @@ class DLRMResNet(nn.Module): vocab_size: int = 32 * 128 * 1024 # 4_194_304 num_dense_features: int = 13 mlp_bottom_dims: Sequence[int] = (256, 256, 256) - mlp_top_dims: Sequence[int] = (256, 256, 256, 1) + mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. From 04da955a63f1d3fb868ff2ad07b687130f9fb2df Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 01:54:47 +0000 Subject: [PATCH 044/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 316401649..f0e1113fa 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -35,7 +35,7 @@ def __call__(self, x, train): mlp_bottom_dims = self.mlp_bottom_dims bot_mlp_input = nn.Dense( - mlp_bottom_dims[0], + 256, kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal( stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])), @@ -45,7 +45,7 @@ def __call__(self, x, train): for dense_dim in mlp_bottom_dims[1:]: x = nn.Dense( - dense_dim, + 256, kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), )( From 0b3c7a96797a3f1870b71c2813262393b37dc4d3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 02:02:31 +0000 Subject: [PATCH 045/169] debugging --- .../criteo1tb/criteo1tb_jax/models.py | 164 +++++++++--------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index f0e1113fa..0eec83eee 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -130,85 +130,85 @@ def dot_interact(concat_features): return activations -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. 
- - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - vocab_size: int = 32 * 128 * 1024 # 4_194_304. - num_dense_features: int = 13 - mlp_bottom_dims: Sequence[int] = (512, 256, 128) - mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) - embed_dim: int = 128 - dropout_rate: float = 0.0 - use_layer_norm: bool = False - - @nn.compact - def __call__(self, x, train): - bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) - cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - - # Bottom MLP. - for dense_dim in self.mlp_bottom_dims: - bot_mlp_input = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )( - bot_mlp_input) - bot_mlp_input = nn.relu(bot_mlp_input) - if self.use_layer_norm: - bot_mlp_input = nn.LayerNorm()(bot_mlp_input) - bot_mlp_output = bot_mlp_input - batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) - - # Embedding table look-up. - idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - - def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / - jnp.sqrt(self.vocab_size)) - - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) - - idx_lookup = jnp.reshape(idx_lookup, [-1]) - embed_features = embedding_table[idx_lookup] - embed_features = jnp.reshape(embed_features, - [batch_size, -1, self.embed_dim]) - if self.use_layer_norm: - embed_features = nn.LayerNorm()(embed_features) - feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) - dot_interact_output = dot_interact(concat_features=feature_stack) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) - mlp_input_dim = top_mlp_input.shape[1] - mlp_top_dims = self.mlp_top_dims - num_layers_top = len(mlp_top_dims) - for layer_idx, fan_out in enumerate(mlp_top_dims): - fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] - top_mlp_input = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( - top_mlp_input) - if layer_idx < (num_layers_top - 1): - top_mlp_input = nn.relu(top_mlp_input) - if self.use_layer_norm: - top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) - logits = top_mlp_input - return logits +# class DlrmSmall(nn.Module): +# """Define a DLRM-Small model. + +# Parameters: +# vocab_size: vocab size of embedding table. +# num_dense_features: number of dense features as the bottom mlp input. +# mlp_bottom_dims: dimensions of dense layers of the bottom mlp. +# mlp_top_dims: dimensions of dense layers of the top mlp. +# embed_dim: embedding dimension. +# """ + +# vocab_size: int = 32 * 128 * 1024 # 4_194_304. 
+# num_dense_features: int = 13 +# mlp_bottom_dims: Sequence[int] = (512, 256, 128) +# mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) +# embed_dim: int = 128 +# dropout_rate: float = 0.0 +# use_layer_norm: bool = False + +# @nn.compact +# def __call__(self, x, train): +# bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) +# cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + +# # Bottom MLP. +# for dense_dim in self.mlp_bottom_dims: +# bot_mlp_input = nn.Dense( +# dense_dim, +# kernel_init=jnn.initializers.glorot_uniform(), +# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), +# )( +# bot_mlp_input) +# bot_mlp_input = nn.relu(bot_mlp_input) +# if self.use_layer_norm: +# bot_mlp_input = nn.LayerNorm()(bot_mlp_input) +# bot_mlp_output = bot_mlp_input +# batch_size = bot_mlp_output.shape[0] +# feature_stack = jnp.reshape(bot_mlp_output, +# [batch_size, -1, self.embed_dim]) + +# # Embedding table look-up. +# idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + +# def scaled_init(key, shape, dtype=jnp.float_): +# return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / +# jnp.sqrt(self.vocab_size)) + +# embedding_table = self.param('embedding_table', +# scaled_init, [self.vocab_size, self.embed_dim]) + +# idx_lookup = jnp.reshape(idx_lookup, [-1]) +# embed_features = embedding_table[idx_lookup] +# embed_features = jnp.reshape(embed_features, +# [batch_size, -1, self.embed_dim]) +# if self.use_layer_norm: +# embed_features = nn.LayerNorm()(embed_features) +# feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) +# dot_interact_output = dot_interact(concat_features=feature_stack) +# top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], +# axis=-1) +# mlp_input_dim = top_mlp_input.shape[1] +# mlp_top_dims = self.mlp_top_dims +# num_layers_top = len(mlp_top_dims) +# for layer_idx, fan_out in enumerate(mlp_top_dims): +# fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] +# top_mlp_input = nn.Dense( +# fan_out, +# kernel_init=jnn.initializers.normal( +# stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), +# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( +# top_mlp_input) +# if layer_idx < (num_layers_top - 1): +# top_mlp_input = nn.relu(top_mlp_input) +# if self.use_layer_norm: +# top_mlp_input = nn.LayerNorm()(top_mlp_input) +# if (self.dropout_rate is not None and self.dropout_rate > 0.0 and +# layer_idx == num_layers_top - 2): +# top_mlp_input = nn.Dropout( +# rate=self.dropout_rate, deterministic=not train)( +# top_mlp_input) +# logits = top_mlp_input +# return logits From 76169a27a444e1b5b7b0acca5ae9c65984301c04 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 02:10:44 +0000 Subject: [PATCH 046/169] mlp dims --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 +++- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 2 ++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 0eec83eee..69b3bcfae 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -35,7 +35,7 @@ def __call__(self, x, train): mlp_bottom_dims = self.mlp_bottom_dims bot_mlp_input = nn.Dense( - 256, + mlp_bottom_dims[0], 
kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal( stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])), @@ -45,7 +45,7 @@ def __call__(self, x, train): for dense_dim in mlp_bottom_dims[1:]: x = nn.Dense( - 256, + dense_dim, kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), )( diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 3fc7f849b..947205a9e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,7 +178,9 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - + mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) + mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) + @property def use_resnet(self) -> bool: """Whether or not to use residual connections in the model.""" diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index bcd5c1fd8..eab47738e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -261,6 +261,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): + mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) + mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) @property def use_resnet(self) -> bool: From cf7f2212cc57ca15700c2c8533f1e69b23e49f8d Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 02:24:03 +0000 Subject: [PATCH 047/169] variant fix --- .../criteo1tb/criteo1tb_jax/models.py | 166 +++++++++--------- .../criteo1tb/criteo1tb_jax/workload.py | 3 +- .../criteo1tb/criteo1tb_pytorch/models.py | 2 +- .../criteo1tb/criteo1tb_pytorch/workload.py | 1 + 4 files changed, 87 insertions(+), 85 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 69b3bcfae..fda3ed850 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -22,7 +22,7 @@ class DLRMResNet(nn.Module): num_dense_features: int = 13 mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) - embed_dim: int = 128 + embed_dim: int = 256 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. @@ -130,85 +130,85 @@ def dot_interact(concat_features): return activations -# class DlrmSmall(nn.Module): -# """Define a DLRM-Small model. - -# Parameters: -# vocab_size: vocab size of embedding table. -# num_dense_features: number of dense features as the bottom mlp input. -# mlp_bottom_dims: dimensions of dense layers of the bottom mlp. -# mlp_top_dims: dimensions of dense layers of the top mlp. -# embed_dim: embedding dimension. -# """ - -# vocab_size: int = 32 * 128 * 1024 # 4_194_304. 
-# num_dense_features: int = 13 -# mlp_bottom_dims: Sequence[int] = (512, 256, 128) -# mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) -# embed_dim: int = 128 -# dropout_rate: float = 0.0 -# use_layer_norm: bool = False - -# @nn.compact -# def __call__(self, x, train): -# bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) -# cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - -# # Bottom MLP. -# for dense_dim in self.mlp_bottom_dims: -# bot_mlp_input = nn.Dense( -# dense_dim, -# kernel_init=jnn.initializers.glorot_uniform(), -# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), -# )( -# bot_mlp_input) -# bot_mlp_input = nn.relu(bot_mlp_input) -# if self.use_layer_norm: -# bot_mlp_input = nn.LayerNorm()(bot_mlp_input) -# bot_mlp_output = bot_mlp_input -# batch_size = bot_mlp_output.shape[0] -# feature_stack = jnp.reshape(bot_mlp_output, -# [batch_size, -1, self.embed_dim]) - -# # Embedding table look-up. -# idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - -# def scaled_init(key, shape, dtype=jnp.float_): -# return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / -# jnp.sqrt(self.vocab_size)) - -# embedding_table = self.param('embedding_table', -# scaled_init, [self.vocab_size, self.embed_dim]) - -# idx_lookup = jnp.reshape(idx_lookup, [-1]) -# embed_features = embedding_table[idx_lookup] -# embed_features = jnp.reshape(embed_features, -# [batch_size, -1, self.embed_dim]) -# if self.use_layer_norm: -# embed_features = nn.LayerNorm()(embed_features) -# feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) -# dot_interact_output = dot_interact(concat_features=feature_stack) -# top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], -# axis=-1) -# mlp_input_dim = top_mlp_input.shape[1] -# mlp_top_dims = self.mlp_top_dims -# num_layers_top = len(mlp_top_dims) -# for layer_idx, fan_out in enumerate(mlp_top_dims): -# fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] -# top_mlp_input = nn.Dense( -# fan_out, -# kernel_init=jnn.initializers.normal( -# stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), -# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( -# top_mlp_input) -# if layer_idx < (num_layers_top - 1): -# top_mlp_input = nn.relu(top_mlp_input) -# if self.use_layer_norm: -# top_mlp_input = nn.LayerNorm()(top_mlp_input) -# if (self.dropout_rate is not None and self.dropout_rate > 0.0 and -# layer_idx == num_layers_top - 2): -# top_mlp_input = nn.Dropout( -# rate=self.dropout_rate, deterministic=not train)( -# top_mlp_input) -# logits = top_mlp_input -# return logits +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + vocab_size: int = 32 * 128 * 1024 # 4_194_304. + num_dense_features: int = 13 + mlp_bottom_dims: Sequence[int] = (512, 256, 128) + mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) + embed_dim: int = 128 + dropout_rate: float = 0.0 + use_layer_norm: bool = False + + @nn.compact + def __call__(self, x, train): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) + cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + + # Bottom MLP. 
+ for dense_dim in self.mlp_bottom_dims: + bot_mlp_input = nn.Dense( + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), + )( + bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) + if self.use_layer_norm: + bot_mlp_input = nn.LayerNorm()(bot_mlp_input) + bot_mlp_output = bot_mlp_input + batch_size = bot_mlp_output.shape[0] + feature_stack = jnp.reshape(bot_mlp_output, + [batch_size, -1, self.embed_dim]) + + # Embedding table look-up. + idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + + def scaled_init(key, shape, dtype=jnp.float_): + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / + jnp.sqrt(self.vocab_size)) + + embedding_table = self.param('embedding_table', + scaled_init, [self.vocab_size, self.embed_dim]) + + idx_lookup = jnp.reshape(idx_lookup, [-1]) + embed_features = embedding_table[idx_lookup] + embed_features = jnp.reshape(embed_features, + [batch_size, -1, self.embed_dim]) + if self.use_layer_norm: + embed_features = nn.LayerNorm()(embed_features) + feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) + dot_interact_output = dot_interact(concat_features=feature_stack) + top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], + axis=-1) + mlp_input_dim = top_mlp_input.shape[1] + mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) + for layer_idx, fan_out in enumerate(mlp_top_dims): + fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] + top_mlp_input = nn.Dense( + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( + top_mlp_input) + if layer_idx < (num_layers_top - 1): + top_mlp_input = nn.relu(top_mlp_input) + if self.use_layer_norm: + top_mlp_input = nn.LayerNorm()(top_mlp_input) + if (self.dropout_rate is not None and self.dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + top_mlp_input = nn.Dropout( + rate=self.dropout_rate, deterministic=not train)( + top_mlp_input) + logits = top_mlp_input + return logits diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 947205a9e..fed10d536 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -180,7 +180,8 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) - + embed_dim: int = 256 + @property def use_resnet(self) -> bool: """Whether or not to use residual connections in the model.""" diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index b6a838c81..2b3b1be10 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -51,7 +51,7 @@ def __init__(self, num_sparse_features=26, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=128, + embed_dim=256, dropout_rate=0.0, use_layer_norm=False): super().__init__() diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index eab47738e..3671572ef 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -263,6 +263,7 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) + embed_dim: int = 256 @property def use_resnet(self) -> bool: From 6fe09017cacf62b1716428b4a6018166caf61d7b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 02:32:53 +0000 Subject: [PATCH 048/169] fix dlrm variant --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 2b3b1be10..1ca847621 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -216,7 +216,7 @@ def __init__(self, self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) # TODO: Write down the formula here instead of the constant. - input_dims = 506 + input_dims = 634 top_mlp_layers = [] num_layers_top = len(self.mlp_top_dims) for layer_idx, fan_out in enumerate(self.mlp_top_dims): From 0723d825f5af84e8f2b74f2a7435cf797b6e0d5c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 02:35:28 +0000 Subject: [PATCH 049/169] dlrm fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 1ca847621..21b934d95 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -97,7 +97,7 @@ def __init__(self, self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) # TODO: Write down the formula here instead of the constant. - input_dims = 506 + input_dims = 634 top_mlp_layers = [] num_layers_top = len(self.mlp_top_dims) for layer_idx, fan_out in enumerate(self.mlp_top_dims): @@ -216,7 +216,7 @@ def __init__(self, self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) # TODO: Write down the formula here instead of the constant. 
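The TODO above asks for the formula behind the hardcoded width. The constants used across these commits are consistent with the top MLP seeing the bottom-MLP output concatenated with the flattened pairwise feature interactions, self-interaction terms included; this is an inference from the 506/634 values, not something spelled out in the diffs:

# Hypothetical helper, reconstructed from the constants rather than copied from the repo.
def top_mlp_input_dims(embed_dim, num_sparse_features=26):
  num_features = num_sparse_features + 1  # 26 embedded sparse features + 1 bottom-MLP vector
  num_interactions = num_features * (num_features + 1) // 2  # 27 * 28 // 2 = 378
  return embed_dim + num_interactions


print(top_mlp_input_dims(128))  # 506, the DlrmSmall constant restored below
print(top_mlp_input_dims(256))  # 634, the DLRMResNet constant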
- input_dims = 634 + input_dims = 506 top_mlp_layers = [] num_layers_top = len(self.mlp_top_dims) for layer_idx, fan_out in enumerate(self.mlp_top_dims): From d56269b152a43bbfb517094ce341a35cf55b8b18 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:03:03 +0000 Subject: [PATCH 050/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 3671572ef..9937c0a4c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -79,6 +79,11 @@ def init_model_fn( model_class = models.DLRMResNet else: model_class = models.DlrmSmall + print(self.vocab_size) + print(self.num_dense_features) + print(self.mlp_bottom_dims) + print(self.mlp_top_dims) + print(self.embed_dim) model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, From 61f052b4c42b588475447ccef1beac8dac7de773 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:05:25 +0000 Subject: [PATCH 051/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index fda3ed850..08d8c9016 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -28,6 +28,11 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train): + print(self.vocab_size) + print(self.num_dense_features) + print(self.mlp_bottom_dims) + print(self.mlp_top_dims) + print(self.embed_dim) bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) From ff5aa919eaf6a9b93743830412af8d4b7b8be79a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:08:19 +0000 Subject: [PATCH 052/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 08d8c9016..e484caf4a 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -69,6 +69,7 @@ def scaled_init(key, shape, dtype=jnp.float_): embed_features = embedding_table[idx_lookup] batch_size = bot_mlp_input.shape[0] + print(jnp.shape(embed_features)) embed_features = jnp.reshape(embed_features, (batch_size, 26 * self.embed_dim)) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) From df48937d1ec5ff5a54e3477e1f6ff71fda159193 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:12:29 +0000 Subject: [PATCH 053/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index e484caf4a..23711d2e7 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ 
-69,6 +69,7 @@ def scaled_init(key, shape, dtype=jnp.float_): embed_features = embedding_table[idx_lookup] batch_size = bot_mlp_input.shape[0] + print('embed_features shape') print(jnp.shape(embed_features)) embed_features = jnp.reshape(embed_features, (batch_size, 26 * self.embed_dim)) From 8a574e33b47035dcc85db14d0b3b1a9f498e9953 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:14:59 +0000 Subject: [PATCH 054/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 23711d2e7..569f2cbb3 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -71,6 +71,7 @@ def scaled_init(key, shape, dtype=jnp.float_): batch_size = bot_mlp_input.shape[0] print('embed_features shape') print(jnp.shape(embed_features)) + print(jnp.shape(bot_mlp_input)) embed_features = jnp.reshape(embed_features, (batch_size, 26 * self.embed_dim)) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) From 82e42e3ce198cfc612423d5b70d7d0d4ef532548 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:22:37 +0000 Subject: [PATCH 055/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 569f2cbb3..b2f11a597 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -78,6 +78,7 @@ def scaled_init(key, shape, dtype=jnp.float_): mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) + print(jnp.shape(top_mlp_input)) top_mlp_input = nn.Dense( mlp_top_dims[0], kernel_init=jnn.initializers.normal( From f2beed94d85d1888dd8c28032d9b6f9555d79bb8 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 03:28:08 +0000 Subject: [PATCH 056/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index fed10d536..957e4eb80 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -178,7 +178,7 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) + mlp_bottom_dims: Tuple[int, int] = (512, 512, 512) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) embed_dim: int = 256 From 17c1190b998d35446dc6aff05c334a21d66c2e13 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 06:29:37 +0000 Subject: [PATCH 057/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 957e4eb80..0f8014bb1 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -2,11 +2,12 @@ import functools from typing import Dict, Optional, Tuple - +from absl import logging from flax import jax_utils import jax import jax.numpy as jnp import numpy as np +import flax.linen as nn from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -99,9 +100,15 @@ def init_model_fn( initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape, jnp.float32)) + fake_inputs = jnp.ones(input_shape, jnp.float32) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) + tabulate_fn = nn.tabulate(self._model.flax_module, jax.random.PRNGKey(0), + console_kwargs={'force_terminal': False, + 'force_jupyter': False, + 'width': 240},) + logging.info(tabulate_fn(*fake_inputs, train=False,)) return jax_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -178,7 +185,7 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): - mlp_bottom_dims: Tuple[int, int] = (512, 512, 512) + mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) embed_dim: int = 256 From e5a31f57e25bed0219f3ee633bb9dfd6823ae8a9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 06:31:08 +0000 Subject: [PATCH 058/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 0f8014bb1..4ec8021c0 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -104,7 +104,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - tabulate_fn = nn.tabulate(self._model.flax_module, jax.random.PRNGKey(0), + tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240},) From 7c98352611d524ccd8933bf2f49f5030550edf97 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 06:34:28 +0000 Subject: [PATCH 059/169] debuggingg --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 4ec8021c0..da29d9aef 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -108,7 +108,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240},) - logging.info(tabulate_fn(*fake_inputs, train=False,)) + logging.info(tabulate_fn(*fake_inputs)) return jax_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From c9facb23fe32eea739f429d7737926e74c89ce52 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 06:55:26 +0000 Subject: [PATCH 060/169] 
debugging --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index da29d9aef..c40ad1168 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -108,7 +108,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240},) - logging.info(tabulate_fn(*fake_inputs)) + logging.info(tabulate_fn(fake_inputs)) return jax_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From 549be97a2985fd6b207acd39e758f36bb2f8cdf4 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 06:56:44 +0000 Subject: [PATCH 061/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index c40ad1168..67795c9bd 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -108,7 +108,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240},) - logging.info(tabulate_fn(fake_inputs)) + logging.info(tabulate_fn(fake_inputs, train=False)) return jax_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From 6967445d4c1f83069850212765d46285608c55a9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 06:59:17 +0000 Subject: [PATCH 062/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 67795c9bd..a6e4a43d4 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -108,7 +108,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240},) - logging.info(tabulate_fn(fake_inputs, train=False)) + print(tabulate_fn(fake_inputs, train=False)) return jax_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From 5d237ebef15729faca70f5fad98c085c1e82bb0b Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 07:11:24 +0000 Subject: [PATCH 063/169] debug --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 21b934d95..3428081b1 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -85,7 +85,7 @@ def __init__(self, if use_layer_norm: bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim - self.bot_mlp = nn.Sequential(*bottom_mlp_layers) + self.bot_mlp = nn.Sequential([ResNetBlock(layer) for layer in bottom_mlp_layers]) for 
module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): limit = math.sqrt(6. / (module.in_features + module.out_features)) @@ -111,7 +111,7 @@ def __init__(self, if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*top_mlp_layers) + self.top_mlp = nn.Sequential([ResNetBlock(layer) for layer in top_mlp_layers]) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: @@ -204,7 +204,7 @@ def __init__(self, if use_layer_norm: bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim - self.bot_mlp = nn.Sequential([ResNetBlock(layer) for layer in bottom_mlp_layers]) + self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): limit = math.sqrt(6. / (module.in_features + module.out_features)) @@ -230,7 +230,7 @@ def __init__(self, if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential([ResNetBlock(layer) for layer in top_mlp_layers]) + self.top_mlp = nn.Sequential(*top_mlp_layers) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: From 9a132b153e404e3978e9e824bf014029843dc3f6 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 17 Nov 2023 07:19:11 +0000 Subject: [PATCH 064/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 3428081b1..7f36789cd 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -85,7 +85,7 @@ def __init__(self, if use_layer_norm: bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim - self.bot_mlp = nn.Sequential([ResNetBlock(layer) for layer in bottom_mlp_layers]) + self.bot_mlp = nn.Sequential(*[ResNetBlock(layer) for layer in bottom_mlp_layers]) for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): limit = math.sqrt(6. 
/ (module.in_features + module.out_features)) @@ -111,7 +111,7 @@ def __init__(self, if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential([ResNetBlock(layer) for layer in top_mlp_layers]) + self.top_mlp = nn.Sequential(*[ResNetBlock(layer) for layer in top_mlp_layers]) if use_layer_norm: self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) else: From 89dea7220fc784ae0c935fa02fce49d81c729bef Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 00:52:39 +0000 Subject: [PATCH 065/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index d903236aa..97f2cb078 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -95,6 +95,7 @@ def init_model_fn( model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) + print(model) return model, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: From fdca6c576e952c9517bc98fc6709b0f38d7beb25 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 01:33:22 +0000 Subject: [PATCH 066/169] debugging --- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 761cc47bc..af1eb57f4 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -24,6 +24,10 @@ def key_transform(k): continue if 'Embedding' in i: return ('embedding_table',) + if 'ResNetBlock' in i: + i = i.replace('ResNetBlock', 'Dense') + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') From 29835e363a5833dca29f5ee930aa832bb6fd81eb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 01:36:15 +0000 Subject: [PATCH 067/169] debugging --- tests/modeldiffs/torch2jax_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 1926ab0cc..4fdca517e 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,6 +77,8 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): + for k in self.pytorch_sd: + print(k) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From 3a859eb522e3053fe4289ee4ba17dbb5d76fd0a4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:04:15 +0000 Subject: [PATCH 068/169] clarify output of diff test --- tests/modeldiffs/diff.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index bc53de875..8dc948fc3 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -56,5 +56,8 @@ def out_diff(jax_workload, out_p = out_transform(out_p) out_j = out_transform(out_j) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + max_diff = np.abs(out_p.detach().numpy() - 
np.array(out_j)).max() + min_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).min() + + logging.info(d'Max fprop difference between jax and pytorch: {max_diff}') + logging.info(d'Min fprop difference between jax and pytorch: {min_diff}') From 78298637ee1dc82672b5771820c1d4238d9fd16b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:14:45 +0000 Subject: [PATCH 069/169] key transform --- tests/modeldiffs/criteo1tb_resnet/compare.py | 5 ++++- tests/modeldiffs/torch2jax_utils.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index af1eb57f4..582cb2412 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -16,9 +16,12 @@ def key_transform(k): + print('key transform: ') new_key = [] s_count = None + print(k) for i in k: + print(i) if 'Sequential' in i: s_count = int(i.split('_')[1]) continue @@ -34,7 +37,7 @@ def key_transform(k): i = name + '_' + str(s_count * 3 + int(count)) elif 'weight' in i: i = i.replace('weight', 'kernel') - + print(i) new_key.append(i) return tuple(new_key) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 4fdca517e..46502987c 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -78,7 +78,6 @@ def key_transform(self, k_transform_fn): def value_transform(self, v_transform_fn): for k in self.pytorch_sd: - print(k) self.pytorch_sd = { k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) for k in self.pytorch_sd From a47e6b14155ed162a3f8a9d2ce7b79d0fa693603 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:15:47 +0000 Subject: [PATCH 070/169] syntaxl --- tests/modeldiffs/diff.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 8dc948fc3..edad14268 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -59,5 +59,5 @@ def out_diff(jax_workload, max_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).max() min_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).min() - logging.info(d'Max fprop difference between jax and pytorch: {max_diff}') - logging.info(d'Min fprop difference between jax and pytorch: {min_diff}') + logging.info(f'Max fprop difference between jax and pytorch: {max_diff}') + logging.info(f'Min fprop difference between jax and pytorch: {min_diff}') From f75944df2849bddf1a700274dfd3324d2e07c36f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:17:25 +0000 Subject: [PATCH 071/169] diff --- tests/modeldiffs/torch2jax_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 46502987c..07f7cc360 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -78,10 +78,10 @@ def key_transform(self, k_transform_fn): def value_transform(self, v_transform_fn): for k in self.pytorch_sd: - self.pytorch_sd = { - k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) - for k in self.pytorch_sd - } + self.pytorch_sd = { + k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) + for k in self.pytorch_sd + } def sd_transform(self, sd_transform_fn): self.pytorch_sd = sd_transform_fn(self.pytorch_sd) From ef68ba8122bd4d16ab2fa38d430fcdd2882f9e2d Mon Sep 17 00:00:00 2001 
From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:19:02 +0000 Subject: [PATCH 072/169] logging --- tests/modeldiffs/diff.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index edad14268..a877e28c2 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -1,3 +1,4 @@ +from absl import logging from flax import jax_utils import jax import numpy as np From e096a3ab859b50183a6f945260b1808cc82e6d09 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:21:26 +0000 Subject: [PATCH 073/169] debug --- tests/modeldiffs/criteo1tb/compare.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 9a95f3656..77ac04b12 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -16,9 +16,11 @@ def key_transform(k): + print('key transform: ') new_key = [] s_count = None for i in k: + print(i) if 'Sequential' in i: s_count = int(i.split('_')[1]) continue @@ -32,6 +34,7 @@ def key_transform(k): i = i.replace('weight', 'kernel') new_key.append(i) + print(i) return tuple(new_key) From 0c7a86453fb2a45237b3d9068fd7606319be82f8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:23:32 +0000 Subject: [PATCH 074/169] debug --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 9fce6816e..7957c4869 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -79,11 +79,6 @@ def init_model_fn( model_class = models.DLRMResNet else: model_class = models.DlrmSmall - print(self.vocab_size) - print(self.num_dense_features) - print(self.mlp_bottom_dims) - print(self.mlp_top_dims) - print(self.embed_dim) model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, From 271412da4d2915c7938c794f090029566e62bff7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:25:42 +0000 Subject: [PATCH 075/169] debug --- tests/modeldiffs/criteo1tb/compare.py | 4 ++-- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 77ac04b12..a74c41ef7 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -20,7 +20,7 @@ def key_transform(k): new_key = [] s_count = None for i in k: - print(i) + print(f'in transform: {i}') if 'Sequential' in i: s_count = int(i.split('_')[1]) continue @@ -34,7 +34,7 @@ def key_transform(k): i = i.replace('weight', 'kernel') new_key.append(i) - print(i) + print(f'out transform: {i}') return tuple(new_key) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 582cb2412..1a7a2143d 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -21,7 +21,7 @@ def key_transform(k): s_count = None print(k) for i in k: - print(i) + print(f'in transform {i}') if 'Sequential' in i: s_count = int(i.split('_')[1]) continue @@ -37,7 +37,7 @@ def key_transform(k): i = name + '_' + str(s_count * 3 + int(count)) elif 'weight' in i: i = i.replace('weight', 
'kernel') - print(i) + print(f'out transform {i}') new_key.append(i) return tuple(new_key) From 650ef436f44518315f9ac08c2276745949da8ba7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:35:07 +0000 Subject: [PATCH 076/169] debugging --- tests/modeldiffs/criteo1tb/compare.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index a74c41ef7..669bccddd 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -19,6 +19,7 @@ def key_transform(k): print('key transform: ') new_key = [] s_count = None + print(k) for i in k: print(f'in transform: {i}') if 'Sequential' in i: From ed1cdbaf17ac5dba7b0d36cda60c412708a92cf0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:43:00 +0000 Subject: [PATCH 077/169] debug --- tests/modeldiffs/criteo1tb/compare.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 669bccddd..1154d2276 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -36,6 +36,7 @@ def key_transform(k): new_key.append(i) print(f'out transform: {i}') + print(f'new key {new_key}') return tuple(new_key) From f59b48d8e084fb81fc313194fa8e0a0d634d9ae4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 02:43:08 +0000 Subject: [PATCH 078/169] debug --- tests/modeldiffs/criteo1tb_resnet/compare.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 1a7a2143d..362b11317 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -39,6 +39,7 @@ def key_transform(k): i = i.replace('weight', 'kernel') print(f'out transform {i}') new_key.append(i) + print(f'new key {new_key}') return tuple(new_key) From 47e71667d6621b87497c3ed79599693325658252 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 03:24:11 +0000 Subject: [PATCH 079/169] debug --- tests/modeldiffs/criteo1tb/compare.py | 1 - tests/modeldiffs/criteo1tb_resnet/compare.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 1154d2276..491cae4c5 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -33,7 +33,6 @@ def key_transform(k): i = name + '_' + str(s_count * 3 + int(count)) elif 'weight' in i: i = i.replace('weight', 'kernel') - new_key.append(i) print(f'out transform: {i}') print(f'new key {new_key}') diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 362b11317..8fbeb09c5 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -19,6 +19,7 @@ def key_transform(k): print('key transform: ') new_key = [] s_count = None + resnet_count = None print(k) for i in k: print(f'in transform {i}') @@ -30,11 +31,12 @@ def key_transform(k): if 'ResNetBlock' in i: i = i.replace('ResNetBlock', 'Dense') name, count = i.split('_') - i = name + '_' + str(s_count * 3 + int(count)) + resnet_count = resnet_count + 1 + continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') - i = name + '_' + str(s_count * 3 + int(count)) + i = name + '_' + str(s_count * 3 + int(resnet_count)) elif 'weight' in i: i = 
i.replace('weight', 'kernel') print(f'out transform {i}') From 39a6c311cd4cd88917cea08273891f156a59f3e4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 03:26:39 +0000 Subject: [PATCH 080/169] debug --- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 8fbeb09c5..6eb7583c4 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -18,8 +18,8 @@ def key_transform(k): print('key transform: ') new_key = [] - s_count = None - resnet_count = None + s_count = 0 + resnet_count = 0 print(k) for i in k: print(f'in transform {i}') From f70dd49f9c6e60038be4c0bb07d962903c8bbc31 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 03:52:01 +0000 Subject: [PATCH 081/169] debug --- tests/modeldiffs/criteo1tb_resnet/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 6eb7583c4..b5b0cfb62 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -29,13 +29,13 @@ def key_transform(k): if 'Embedding' in i: return ('embedding_table',) if 'ResNetBlock' in i: - i = i.replace('ResNetBlock', 'Dense') name, count = i.split('_') resnet_count = resnet_count + 1 continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') + print(resnet_count) i = name + '_' + str(s_count * 3 + int(resnet_count)) elif 'weight' in i: i = i.replace('weight', 'kernel') From 0966015bc74c4fc775c055840780d18ce19cbef5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 03:54:47 +0000 Subject: [PATCH 082/169] debug --- tests/modeldiffs/criteo1tb_resnet/compare.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index b5b0cfb62..0253f65bc 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -18,8 +18,8 @@ def key_transform(k): print('key transform: ') new_key = [] - s_count = 0 - resnet_count = 0 + s_count = None + resnet_count = None print(k) for i in k: print(f'in transform {i}') @@ -30,7 +30,7 @@ def key_transform(k): return ('embedding_table',) if 'ResNetBlock' in i: name, count = i.split('_') - resnet_count = resnet_count + 1 + resnet_count = int(count) continue if 'Linear' in i: i = i.replace('Linear', 'Dense') From d7d4638e35ee16463cb58897fb165eb04a0bac46 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 04:01:15 +0000 Subject: [PATCH 083/169] remove some debugging statements --- tests/modeldiffs/criteo1tb/compare.py | 5 ----- tests/modeldiffs/criteo1tb_resnet/compare.py | 6 ------ 2 files changed, 11 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 491cae4c5..8c5881a8e 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -16,12 +16,9 @@ def key_transform(k): - print('key transform: ') new_key = [] s_count = None - print(k) for i in k: - print(f'in transform: {i}') if 'Sequential' in i: s_count = int(i.split('_')[1]) continue @@ -34,8 +31,6 @@ def key_transform(k): elif 'weight' in i: i = i.replace('weight', 'kernel') new_key.append(i) - print(f'out transform: {i}') - 
print(f'new key {new_key}') return tuple(new_key) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 0253f65bc..92bd278b0 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -16,13 +16,10 @@ def key_transform(k): - print('key transform: ') new_key = [] s_count = None resnet_count = None - print(k) for i in k: - print(f'in transform {i}') if 'Sequential' in i: s_count = int(i.split('_')[1]) continue @@ -35,13 +32,10 @@ def key_transform(k): if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') - print(resnet_count) i = name + '_' + str(s_count * 3 + int(resnet_count)) elif 'weight' in i: i = i.replace('weight', 'kernel') - print(f'out transform {i}') new_key.append(i) - print(f'new key {new_key}') return tuple(new_key) From b1c35d4623dd9657bf96a42bb2052d139a160972 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 18 Nov 2023 04:08:39 +0000 Subject: [PATCH 084/169] add debugging statement --- tests/modeldiffs/torch2jax_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 07f7cc360..9600cd204 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -106,6 +106,7 @@ def diff(self): if s_p == s_j: count += 1 else: + print('Difference in pytorch and jax key:') print(k, s_p, s_j) print(f'Number of values with identical shapes: {count}') From 09de0e69073134378b5c5203c46a99b12403c57a Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 18 Nov 2023 00:00:41 -0800 Subject: [PATCH 085/169] fix mnist weights bug --- .../workloads/mnist/mnist_pytorch/workload.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index b7f33b94b..e638df078 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -80,9 +80,7 @@ def _build_input_queue( weights = torch.as_tensor( batch['weights'], dtype=torch.bool, device=DEVICE) else: - weights = torch.ones((batch['targets'].shape[-1],), - dtype=torch.bool, - device=DEVICE) + weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE) # Send batch to other devices when using DDP. 
if USE_PYTORCH_DDP: dist.broadcast(inputs, src=0) From e7e52f0c5e266aff8c49a9c08d8f3457b2d3fc42 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 22:09:23 +0000 Subject: [PATCH 086/169] resnet fix --- .../criteo1tb/criteo1tb_pytorch/models.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 7f36789cd..4daebacdb 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -77,15 +77,19 @@ def __init__(self, self.register_parameter(f'embedding_chunk_{i}', chunk) self.embedding_table_chucks.append(chunk) - bottom_mlp_layers = [] input_dim = self.num_dense_features - for dense_dim in self.mlp_bottom_dims: - bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) - bottom_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) + bot_mlp_blocks = [] + for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = ResNetBlock(block) + bot_mlp_blocks_append(block) input_dim = dense_dim - self.bot_mlp = nn.Sequential(*[ResNetBlock(layer) for layer in bottom_mlp_layers]) + self.bot_mlp = nn.Sequential(*bot_mlp_blocks) + for module in self.bot_mlp.modules(): if isinstance(module, nn.Linear): limit = math.sqrt(6. / (module.in_features + module.out_features)) @@ -95,27 +99,25 @@ def __init__(self, math.sqrt(1. / module.out_features)) self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - # TODO: Write down the formula here instead of the constant. 
- input_dims = 634 - top_mlp_layers = [] + fan_in = 634 num_layers_top = len(self.mlp_top_dims) + mlp_top_blocks = [] for layer_idx, fan_out in enumerate(self.mlp_top_dims): - fan_in = input_dims if layer_idx == 0 \ - else self.mlp_top_dims[layer_idx - 1] - top_mlp_layers.append(nn.Linear(fan_in, fan_out)) + block = [] + block.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): - top_mlp_layers.append(nn.ReLU(inplace=True)) - if use_layer_norm: - top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) + block.append(nn.ReLU(inplace=True)) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): - top_mlp_layers.append(nn.Dropout(p=dropout_rate)) - self.top_mlp = nn.Sequential(*[ResNetBlock(layer) for layer in top_mlp_layers]) - if use_layer_norm: - self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) - else: - self.embed_ln = None + block.append(nn.Dropout(p=dropout_rate)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = ResNetBlock(block) + mlp_top_blocks.append(block) + fan_in = fan_out + self.top_mlp = nn.Sequential(*mlp_top_blocks) + for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( From 55d72d81feb9f129cd191029068cf74f1f5787e1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 22:12:57 +0000 Subject: [PATCH 087/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 1 - .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index b2f11a597..c4265f201 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -56,7 +56,6 @@ def __call__(self, x, train): )( bot_mlp_input) bot_mlp_input += nn.relu(x) - base_init_fn = jnn.initializers.uniform(scale=1.0) # Embedding table init and lookup for a single unified table. 
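    # All 26 categorical ids are folded into this one shared table: they are
    # reduced modulo `vocab_size` before the lookup below.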
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 4daebacdb..67fd21340 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -86,7 +86,7 @@ def __init__(self, block = nn.Sequential(*block) if layer_idx > 0: block = ResNetBlock(block) - bot_mlp_blocks_append(block) + bot_mlp_blocks.append(block) input_dim = dense_dim self.bot_mlp = nn.Sequential(*bot_mlp_blocks) From ca8a00ab414f340d1fa4bb6391218af478914c5a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 23:05:58 +0000 Subject: [PATCH 088/169] debugging --- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 92bd278b0..24f3d84b6 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -19,6 +19,8 @@ def key_transform(k): new_key = [] s_count = None resnet_count = None + print("key before") + print(k) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) @@ -36,6 +38,8 @@ def key_transform(k): elif 'weight' in i: i = i.replace('weight', 'kernel') new_key.append(i) + print("key after") + print(new_key) return tuple(new_key) From d929d7642cec6f10136d8087c78a02804fd67e89 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 23:25:27 +0000 Subject: [PATCH 089/169] compare_fix --- tests/modeldiffs/criteo1tb_resnet/compare.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 24f3d84b6..4db9e165d 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -17,24 +17,24 @@ def key_transform(k): new_key = [] - s_count = None + mlp_count = None + block_count = None resnet_count = None print("key before") print(k) for i in k: - if 'Sequential' in i: - s_count = int(i.split('_')[1]) - continue if 'Embedding' in i: return ('embedding_table',) - if 'ResNetBlock' in i: - name, count = i.split('_') - resnet_count = int(count) + if 'Sequential' in i: + if mlp_count is None: + mlp_count = int(i.split('_')[1]) + else: + block_count = int(i.split(_)[1]) continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') - i = name + '_' + str(s_count * 3 + int(resnet_count)) + i = name + '_' + str(mlp_count * 3 + int(block_count)) elif 'weight' in i: i = i.replace('weight', 'kernel') new_key.append(i) From 5192df778f1954fc74fec21fd653cddcc11b13c8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 23:27:35 +0000 Subject: [PATCH 090/169] fix --- tests/modeldiffs/criteo1tb_resnet/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 4db9e165d..591068400 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -29,7 +29,7 @@ def key_transform(k): if mlp_count is None: mlp_count = int(i.split('_')[1]) else: - block_count = int(i.split(_)[1]) + block_count = int(i.split('_')[1]) continue if 'Linear' in i: i = i.replace('Linear', 'Dense') From 
a88f51696aa61d200a101f0c2be8f6d40375eae8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 23:31:32 +0000 Subject: [PATCH 091/169] fix --- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 591068400..d06759992 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -31,13 +31,15 @@ def key_transform(k): else: block_count = int(i.split('_')[1]) continue + if 'ResNetBlock' in i: + continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') i = name + '_' + str(mlp_count * 3 + int(block_count)) elif 'weight' in i: i = i.replace('weight', 'kernel') - new_key.append(i) + new_key.append(i) print("key after") print(new_key) return tuple(new_key) From 7aca234eec3a50922baff652550ecc54565f6d9e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 20 Nov 2023 23:42:37 +0000 Subject: [PATCH 092/169] fix --- tests/modeldiffs/criteo1tb_resnet/compare.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index d06759992..60d7f7c7b 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -18,8 +18,7 @@ def key_transform(k): new_key = [] mlp_count = None - block_count = None - resnet_count = None + block_count = 0 print("key before") print(k) for i in k: @@ -28,15 +27,14 @@ def key_transform(k): if 'Sequential' in i: if mlp_count is None: mlp_count = int(i.split('_')[1]) - else: - block_count = int(i.split('_')[1]) continue if 'ResNetBlock' in i: + block_count = int(i.split('_')[1]) + 1 continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') - i = name + '_' + str(mlp_count * 3 + int(block_count)) + i = name + '_' + str(mlp_count * 3 + block_count) elif 'weight' in i: i = i.replace('weight', 'kernel') new_key.append(i) From c2d288eafcc57103b5204c369f3f331136ebf865 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 00:05:06 +0000 Subject: [PATCH 093/169] block count --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 -- tests/modeldiffs/criteo1tb_resnet/compare.py | 10 ++++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 67fd21340..9f0db8f17 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -144,8 +144,6 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, -1, self.embed_dim]) - if self.embed_ln: - embedded_sparse = self.embed_ln(embedded_sparse) # Dot product interactions. 
concatenated_dense = self.dot_interact( dense_features=embedded_dense, sparse_features=embedded_sparse) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 60d7f7c7b..f6baf19cb 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -18,7 +18,8 @@ def key_transform(k): new_key = [] mlp_count = None - block_count = 0 + resnet_block_count = None + mlp_block_count = None print("key before") print(k) for i in k: @@ -27,13 +28,18 @@ def key_transform(k): if 'Sequential' in i: if mlp_count is None: mlp_count = int(i.split('_')[1]) + else: + mlp_block_count = int(i.split('_')[1]) continue if 'ResNetBlock' in i: - block_count = int(i.split('_')[1]) + 1 + # off set resnet block count by 1 + # since first mlp layer has no resnet connection + resnet_block_count = int(i.split('_')[1]) + 1 continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') + block_count = max([mlp_block_count, resnet_block_count]) i = name + '_' + str(mlp_count * 3 + block_count) elif 'weight' in i: i = i.replace('weight', 'kernel') From 1bb484f710186a28ef2dca7fb6f2d6b8c1c4463e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 00:06:36 +0000 Subject: [PATCH 094/169] fix --- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index f6baf19cb..f5521a514 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -18,8 +18,8 @@ def key_transform(k): new_key = [] mlp_count = None - resnet_block_count = None - mlp_block_count = None + resnet_block_count = 0 + mlp_block_count = 0 print("key before") print(k) for i in k: From b50e9ddf2206ff426cc00b993b677d590cbfa46e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 00:29:11 +0000 Subject: [PATCH 095/169] fix resnet jax --- .../criteo1tb/criteo1tb_jax/models.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index c4265f201..5304b2140 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -36,7 +36,7 @@ def __call__(self, x, train): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - # bottom mlp + # Bottom MLP mlp_bottom_dims = self.mlp_bottom_dims bot_mlp_input = nn.Dense( @@ -56,28 +56,31 @@ def __call__(self, x, train): )( bot_mlp_input) bot_mlp_input += nn.relu(x) - base_init_fn = jnn.initializers.uniform(scale=1.0) + + bot_mlp_output = bot_mlp_input + batch_size = bot_mlp_output.shape[0] + feature_stack = jnp.reshape(bot_mlp_output, + [batch_size, -1, self.embed_dim]) # Embedding table init and lookup for a single unified table. 
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size def scaled_init(key, shape, dtype=jnp.float_): - return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / + jnp.sqrt(self.vocab_size)) embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) + idx_lookup = jnp.reshape(idx_lookup, [-1]) embed_features = embedding_table[idx_lookup] - batch_size = bot_mlp_input.shape[0] - print('embed_features shape') - print(jnp.shape(embed_features)) - print(jnp.shape(bot_mlp_input)) embed_features = jnp.reshape(embed_features, - (batch_size, 26 * self.embed_dim)) - top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) + [batch_size, -1, self.embed_dim]) + feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) + dot_interact_output = dot_interact(concat_features=feature_stack) + top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], + axis=-1) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims - num_layers_top = len(mlp_top_dims) - print(jnp.shape(top_mlp_input)) top_mlp_input = nn.Dense( mlp_top_dims[0], kernel_init=jnn.initializers.normal( From c83bdad1102173953566dd3a4c66956ebfee9527 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 00:46:03 +0000 Subject: [PATCH 096/169] remove debugging statemetns --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 9 +++------ .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 1 - tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ---- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index a6e4a43d4..d3d908b5c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -74,7 +74,9 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + tabulate: Optional[bool] = False, + ) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate if self.use_resnet: @@ -104,11 +106,6 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), - console_kwargs={'force_terminal': False, - 'force_jupyter': False, - 'width': 240},) - print(tabulate_fn(fake_inputs, train=False)) return jax_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 7957c4869..3671572ef 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -95,7 +95,6 @@ def init_model_fn( model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) - print(model) return model, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py 
b/tests/modeldiffs/criteo1tb_resnet/compare.py index f5521a514..ba578ba69 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -20,8 +20,6 @@ def key_transform(k): mlp_count = None resnet_block_count = 0 mlp_block_count = 0 - print("key before") - print(k) for i in k: if 'Embedding' in i: return ('embedding_table',) @@ -44,8 +42,6 @@ def key_transform(k): elif 'weight' in i: i = i.replace('weight', 'kernel') new_key.append(i) - print("key after") - print(new_key) return tuple(new_key) From a6f2ba07a6bc6862d6c4c29157588e257a7dba10 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 00:50:23 +0000 Subject: [PATCH 097/169] fix logging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 5 ----- tests/modeldiffs/diff.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 5304b2140..04a7c485d 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -28,11 +28,6 @@ class DLRMResNet(nn.Module): @nn.compact def __call__(self, x, train): - print(self.vocab_size) - print(self.num_dense_features) - print(self.mlp_bottom_dims) - print(self.mlp_top_dims) - print(self.embed_dim) bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index a877e28c2..572b87f25 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -1,4 +1,4 @@ -from absl import logging +import logging from flax import jax_utils import jax import numpy as np From b641ab9acfd35cb17c1ffacca44b493a58798e6e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 00:54:04 +0000 Subject: [PATCH 098/169] add back print statemetns --- tests/modeldiffs/diff.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 572b87f25..d56115258 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -60,5 +60,5 @@ def out_diff(jax_workload, max_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).max() min_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).min() - logging.info(f'Max fprop difference between jax and pytorch: {max_diff}') - logging.info(f'Min fprop difference between jax and pytorch: {min_diff}') + print(f'Max fprop difference between jax and pytorch: {max_diff}') + print(f'Min fprop difference between jax and pytorch: {min_diff}') From 62a43c8ea7fe6e8ac2100e71725cc11cbddaba9d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 01:33:16 +0000 Subject: [PATCH 099/169] resnet fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 9f0db8f17..496028543 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -99,6 +99,7 @@ def __init__(self, math.sqrt(1. 
/ module.out_features)) self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) + # TODO: Write down the formula here instead of the constant. fan_in = 634 num_layers_top = len(self.mlp_top_dims) @@ -112,7 +113,7 @@ def __init__(self, layer_idx == num_layers_top - 2): block.append(nn.Dropout(p=dropout_rate)) block = nn.Sequential(*block) - if layer_idx > 0: + if (layer_idx != 0) and (layer_idx != num_layers_top - 1): block = ResNetBlock(block) mlp_top_blocks.append(block) fan_in = fan_out @@ -196,8 +197,8 @@ def __init__(self, self.register_parameter(f'embedding_chunk_{i}', chunk) self.embedding_table_chucks.append(chunk) - bottom_mlp_layers = [] input_dim = self.num_dense_features + bottom_mlp_layers = [] for dense_dim in self.mlp_bottom_dims: bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) bottom_mlp_layers.append(nn.ReLU(inplace=True)) @@ -206,6 +207,7 @@ def __init__(self, input_dim = dense_dim self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): + print(module) if isinstance(module, nn.Linear): limit = math.sqrt(6. / (module.in_features + module.out_features)) nn.init.uniform_(module.weight.data, -limit, limit) @@ -217,8 +219,8 @@ def __init__(self, # TODO: Write down the formula here instead of the constant. input_dims = 506 - top_mlp_layers = [] num_layers_top = len(self.mlp_top_dims) + top_mlp_layers = [] for layer_idx, fan_out in enumerate(self.mlp_top_dims): fan_in = input_dims if layer_idx == 0 \ else self.mlp_top_dims[layer_idx - 1] From 84463e2792a85726cba67a61674b71b233ce150c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 01:45:54 +0000 Subject: [PATCH 100/169] change block structures --- .../criteo1tb/criteo1tb_pytorch/models.py | 21 ++++++++++++------- tests/modeldiffs/criteo1tb_resnet/compare.py | 10 ++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 496028543..8ef0cdb2e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -6,15 +6,18 @@ from torch import nn -class ResNetBlock(nn.Module): - """Resnet block""""" - def __init__(self, module): +class DenseBlock(nn.Module): + """Dense block with optional residual connection.""""" + def __init__(self, module, resnet=False): super().__init__() self.module = module + self.resnet = resnet def forward(self, x): - return self.module(x) + x - + if self.resnet: + return self.module(x) + x + else: + return self.module(x) class DotInteract(nn.Module): """Performs feature interaction operation between dense or sparse features.""" @@ -85,7 +88,9 @@ def __init__(self, block.append(nn.ReLU(inplace=True)) block = nn.Sequential(*block) if layer_idx > 0: - block = ResNetBlock(block) + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) bot_mlp_blocks.append(block) input_dim = dense_dim self.bot_mlp = nn.Sequential(*bot_mlp_blocks) @@ -114,7 +119,9 @@ def __init__(self, block.append(nn.Dropout(p=dropout_rate)) block = nn.Sequential(*block) if (layer_idx != 0) and (layer_idx != num_layers_top - 1): - block = ResNetBlock(block) + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) mlp_top_blocks.append(block) fan_in = fan_out self.top_mlp = nn.Sequential(*mlp_top_blocks) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py 
b/tests/modeldiffs/criteo1tb_resnet/compare.py index ba578ba69..98d954780 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -18,8 +18,8 @@ def key_transform(k): new_key = [] mlp_count = None - resnet_block_count = 0 - mlp_block_count = 0 + resnet_block_count = None + mlp_block_count = None for i in k: if 'Embedding' in i: return ('embedding_table',) @@ -29,15 +29,15 @@ def key_transform(k): else: mlp_block_count = int(i.split('_')[1]) continue - if 'ResNetBlock' in i: + if 'DenseBlock' in i: # off set resnet block count by 1 # since first mlp layer has no resnet connection - resnet_block_count = int(i.split('_')[1]) + 1 + resnet_block_count = int(i.split('_')[1]) continue if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') - block_count = max([mlp_block_count, resnet_block_count]) + block_count = mlp_block_count if mlp_block_count else dense_block_count i = name + '_' + str(mlp_count * 3 + block_count) elif 'weight' in i: i = i.replace('weight', 'kernel') From 638536566144edb90dd0acb73281b318c083b57c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 01:47:15 +0000 Subject: [PATCH 101/169] fix --- tests/modeldiffs/criteo1tb_resnet/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 98d954780..37d63a519 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -37,7 +37,7 @@ def key_transform(k): if 'Linear' in i: i = i.replace('Linear', 'Dense') name, count = i.split('_') - block_count = mlp_block_count if mlp_block_count else dense_block_count + block_count = mlp_block_count if mlp_block_count else resnet_block_count i = name + '_' + str(mlp_count * 3 + block_count) elif 'weight' in i: i = i.replace('weight', 'kernel') From 441df6e881e56d189e2777bada3baea44ba437b1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 19:54:30 +0000 Subject: [PATCH 102/169] add treshold submissions --- .../threshold_submissions/README.md | 0 .../external_tuning/jax_nadamw_full_budget.py | 345 ++++++++++++++++++ .../jax_nadamw_target_setting.py | 171 +++++++++ .../pytorch_nadamw_full_budget.py | 212 +++++++++++ .../pytorch_nadamw_target_setting.py | 171 +++++++++ .../self_tuning/jax_nadamw_full_budget.py | 345 ++++++++++++++++++ .../self_tuning/jax_nadamw_target_setting.py | 171 +++++++++ .../self_tuning/pytorch_nadamw_full_budget.py | 212 +++++++++++ .../pytorch_nadamw_target_setting.py | 212 +++++++++++ 9 files changed, 1839 insertions(+) create mode 100644 reference_algorithms/threshold_submissions/README.md create mode 100644 reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py create mode 100644 reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py create mode 100644 
reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
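+  In effect (ignoring `eps_root`), each step applies, for gradient g and step
+  count t:
+    m = b1 * m + (1 - b1) * g
+    v = b2 * v + (1 - b2) * g**2
+    m_hat = (b1 * m + (1 - b1) * g) / (1 - b1**t)
+    v_hat = v / (1 - b2**t)
+    param = param - learning_rate * (m_hat / (sqrt(v_hat) + eps)
+                                     + weight_decay * param)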
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
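+  # `lax.psum` over the 'batch' pmap axis sums the per-device values, so every
+  # device ends up with the global summed loss, valid-example count, and
+  # gradients; dividing by `n_valid_examples` below then gives the true global
+  # mean rather than a mean of per-device means.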
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..21f2a7b2b --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,171 @@ +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +from flax import jax_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import \ + update_params # pylint: disable=unused-import + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: + learning_rate: this is a fixed global scaling factor. + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + weight_decay: strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: a tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this) + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + power: the power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + target_setting_step_hint = int(0.75 * workload.step_hint) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, + hyperparameters) + + # Create optimizer. + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..71b819e66 --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,212 @@ +"""Submission file for a NAdamW optimizer in PyTorch.""" + +import math +from typing import List + +import torch +from torch import Tensor + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ + update_params # pylint: disable=unused-import + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
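+    # The extra `mul_`/`add_` below folds the Nesterov look-ahead
+    # (beta1 * m_t + (1 - beta1) * g_t) into `exp_avg` before the parameter
+    # update.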
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..21f2a7b2b --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,171 @@ +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +from flax import jax_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import \ + update_params # pylint: disable=unused-import + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: + learning_rate: this is a fixed global scaling factor. + b1: decay rate for the exponentially weighted average of grads. 
+ b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + weight_decay: strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: a tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this) + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + power: the power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + target_setting_step_hint = int(0.75 * workload.step_hint) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, + hyperparameters) + + # Create optimizer. + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. 
+from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. 
+ eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
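+  # Unlike the target-setting variants, which compress the schedule to
+  # 0.75 * step_hint, this full-budget submission uses the entire step_hint,
+  # and beta1 is supplied via the `one_minus_beta1` hyperparameter.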
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. 
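+  # After the in-pmap psum, loss and grad_norm are identical on every device,
+  # so indexing [0] below logs the copy from the first device.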
+ if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..21f2a7b2b --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,171 @@ +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +from flax import jax_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import \ + update_params # pylint: disable=unused-import + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
+ Args: + learning_rate: this is a fixed global scaling factor. + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + weight_decay: strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: a tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this) + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + power: the power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + target_setting_step_hint = int(0.75 * workload.step_hint) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, + hyperparameters) + + # Create optimizer. 
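+  # Only the parameter shapes matter for initializing the NAdamW moment
+  # accumulators, so the optimizer state is built from zero-valued
+  # placeholders derived from workload.param_shapes and then replicated
+  # across devices with jax_utils.replicate.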
+ params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..71b819e66 --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,212 @@ +"""Submission file for a NAdamW optimizer in PyTorch.""" + +import math +from typing import List + +import torch +from torch import Tensor + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ + update_params # pylint: disable=unused-import + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..71b819e66 --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,212 @@ +"""Submission file for a NAdamW optimizer in PyTorch.""" + +import math +from typing import List + +import torch +from torch import Tensor + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ + update_params # pylint: disable=unused-import + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state From 94b360fe85425025138e94825fd07c1d4b4384b3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 21:04:25 +0000 Subject: [PATCH 103/169] add tuning search space --- .../jax_nadamw_target_setting.py | 2 +- .../pytorch_nadamw_full_budget.py | 183 ++++++++-- .../pytorch_nadamw_target_setting.py | 333 ++++++++++-------- .../self_tuning/jax_nadamw_target_setting.py | 2 +- .../self_tuning/pytorch_nadamw_full_budget.py | 183 ++++++++-- .../pytorch_nadamw_target_setting.py | 2 +- 6 files changed, 508 insertions(+), 197 deletions(-) diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py index 21f2a7b2b..8f20bcbc6 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py @@ -162,7 +162,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, + b1=1 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, eps=epsilon, weight_decay=hyperparameters.weight_decay) diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py index 71b819e66..01cffc52e 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py @@ -1,29 +1,32 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import 
CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +40,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +75,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +143,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +159,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. 
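+    # In place: exp_avg <- beta1 * exp_avg + (1 - beta1) * grad and
+    # exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad * grad.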
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -194,19 +198,150 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
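+  # Logged for each of the first 100 steps and every 500 steps thereafter;
+  # grad_norm is the global L2 norm over all parameters with gradients.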
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py index 21f2a7b2b..7aa8160a4 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py @@ -1,12 +1,10 @@ -"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" +"""Submission file for a NAdamW optimizer in PyTorch.""" -from typing import Any, Callable, NamedTuple, Optional, Union +import math +from typing import List -import chex -from flax import jax_utils -import jax -import jax.numpy as jnp -import optax +import torch +from torch import Tensor from algorithmic_efficiency import spec from reference_algorithms.target_setting_algorithms import cosine_warmup @@ -14,131 +12,177 @@ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ update_params # pylint: disable=unused-import -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, 
Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). - Args: - learning_rate: this is a fixed global scaling factor. - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - weight_decay: strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: a tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - Returns: - An (init_fn, update_fn) tuple. - """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this) - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - Args: - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - power: the power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. 
- """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) def init_optimizer_state(workload: spec.Workload, @@ -147,25 +191,22 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters: spec.Hyperparameters, rng: spec.RandomState) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params del model_state del rng - target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) - - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.weight_decay) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py index 21f2a7b2b..8f20bcbc6 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py @@ -162,7 +162,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, + b1=1 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, eps=epsilon, weight_decay=hyperparameters.weight_decay) diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py index 71b819e66..01cffc52e 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py @@ -1,29 +1,32 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import 
CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +40,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +75,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +143,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +159,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. 
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -194,19 +198,150 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py index 71b819e66..7aa8160a4 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py @@ -201,7 +201,7 @@ def init_optimizer_state(workload: spec.Workload, NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), + betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), eps=epsilon, weight_decay=hyperparameters.weight_decay), } From 05427690c7b69c10b62a21c6b0cc43dd43dc0fb8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:16:14 +0000 Subject: [PATCH 104/169] add self tuning threshold submission --- .../threshold_submissions/README.md | 80 ++++++ .../jax_nadamw_target_setting.py | 242 ++++++++++++++--- .../pytorch_nadamw_target_setting.py | 183 +++++++++++-- .../self_tuning/jax_nadamw_full_budget.py | 14 + .../self_tuning/jax_nadamw_target_setting.py | 256 +++++++++++++++--- .../self_tuning/pytorch_nadamw_full_budget.py | 14 + .../pytorch_nadamw_target_setting.py | 197 ++++++++++++-- 7 files changed, 870 insertions(+), 116 deletions(-) diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md index e69de29bb..eb8995408 100644 --- a/reference_algorithms/threshold_submissions/README.md +++ b/reference_algorithms/threshold_submissions/README.md @@ -0,0 +1,80 @@ 
+# Threshold Submissions
+
+## Externally Tuned Ruleset
+
+### JAX
+
+The threshold submissions for JAX are:
+- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json
+```
+
+### PyTorch
+
+The threshold submissions for PyTorch are:
+- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py`
+
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json
+```
+
+## Self-tuning Ruleset
+
+### JAX
+
+The threshold submissions for JAX are:
+- `reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
+
+### PyTorch
+
+The threshold submissions for PyTorch are:
+- `reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
\ No newline at end of file
diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py
index 8f20bcbc6..ef0c11c0d 100644
--- a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py
+++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py
@@ -1,21 +1,30 @@
-"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax."""
+"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
 
-from typing import Any, Callable, NamedTuple, Optional, Union
+import functools
+
+# isort: off
+# We have to turn off isort here to resolve a conflict between isort and yapf.
+from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on import chex from flax import jax_utils import jax +from jax import lax import jax.numpy as jnp import optax from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import + +_GRAD_CLIP_EPS = 1e-6 # Forked from @@ -32,6 +41,7 @@ def nadamw( Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch @@ -39,24 +49,26 @@ def nadamw( Current code implements a simpler version with no momentum decay and slightly different bias correction terms. The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: - learning_rate: this is a fixed global scaling factor. - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - weight_decay: strength of the weight decay regularization. Note that this + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. - weight_decay_mask: a tree with same structure as (or a prefix of) the params + weight_decay_mask: A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, `True` for leaves/subtrees you want to apply the weight decay to, and `False` for those you want to skip. Note that the Nadam gradient transformations are applied to all parameters. + Returns: An (init_fn, update_fn) tuple. """ @@ -75,21 +87,24 @@ def scale_by_nadam(b1: float = 0.9, debias: bool = True, power: float = 0.5) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this) + follows this). + Current code implements a simpler version with no momentum decay and slightly different (standard Adam) bias correction terms. 
The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - power: the power to use in the preconditioner (0.5 in default adam). + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). Returns: An (init_fn, update_fn) tuple. """ @@ -151,21 +166,180 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + # Create optimizer + LR schedule. 
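+  # Target-setting runs schedule the LR over 75% of the workload step hint.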
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, + b1=1.0 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, - eps=epsilon, + eps=1e-8, weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py index 7aa8160a4..530dd3acf 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py @@ -1,29 +1,32 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). 
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +40,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +75,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +143,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +159,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -194,19 +198,150 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, 
new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py index 099613fcf..b35750086 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py @@ -26,6 +26,14 @@ _GRAD_CLIP_EPS = 1e-6 +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -165,7 +173,10 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng + del hyperparameters + hyperparameters=HPARAMS + def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) @@ -267,6 +278,9 @@ def update_params(workload: spec.Workload, del current_params_types del loss_type del eval_results + del hyperparameters + + hyperparameters = HPARAMS optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py index 8f20bcbc6..190720213 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py @@ -1,22 +1,39 @@ -"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" -from typing import Any, Callable, NamedTuple, Optional, Union +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on import chex from flax import jax_utils import jax +from jax import lax import jax.numpy as jnp import optax from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -32,6 +49,7 @@ def nadamw( Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. 
The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch @@ -39,24 +57,26 @@ def nadamw( Current code implements a simpler version with no momentum decay and slightly different bias correction terms. The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: - learning_rate: this is a fixed global scaling factor. - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - weight_decay: strength of the weight decay regularization. Note that this + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. - weight_decay_mask: a tree with same structure as (or a prefix of) the params + weight_decay_mask: A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, `True` for leaves/subtrees you want to apply the weight decay to, and `False` for those you want to skip. Note that the Nadam gradient transformations are applied to all parameters. + Returns: An (init_fn, update_fn) tuple. """ @@ -75,21 +95,24 @@ def scale_by_nadam(b1: float = 0.9, debias: bool = True, power: float = 0.5) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this) + follows this). + Current code implements a simpler version with no momentum decay and slightly different (standard Adam) bias correction terms. The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - power: the power to use in the preconditioner (0.5 in default adam). + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). 
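  Example (a minimal sketch; `params` and `grads` stand for matching pytrees
  that are not defined here):

    tx = optax.chain(scale_by_nadam(b1=0.9, b2=0.999), optax.scale(-1e-3))
    opt_state = tx.init(params)
    updates, opt_state = tx.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
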
Returns: An (init_fn, update_fn) tuple. """ @@ -150,22 +173,187 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng + del hyperparameters - target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + hyperparameters=HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint*0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, + b1=1.0 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, - eps=epsilon, + eps=1e-8, weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
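  # Under `jax.pmap` each device sees only its shard of the batch, so
  # `summed_loss`, `n_valid_examples` and `grad` are per-device partial sums.
  # `lax.psum` over the 'batch' axis sums them across devices; dividing by the
  # global `n_valid_examples` then gives every device the same mean loss and
  # example-averaged gradient.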
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py index 01cffc52e..a1cf612f2 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py @@ -16,6 +16,14 @@ USE_PYTORCH_DDP = pytorch_setup()[0] +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): @@ -197,6 +205,9 @@ def init_optimizer_state(workload: spec.Workload, """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng + del hyperparameters + + hyperparameters = HPARAMS optimizer_state = { 'optimizer': @@ -239,7 +250,10 @@ def update_params(workload: spec.Workload, del current_params_types del loss_type del eval_results + del hyperparameters + hyperparameters = HPARAMS + current_model = current_param_container current_model.train() optimizer_state['optimizer'].zero_grad() diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py index 7aa8160a4..1209abadc 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py @@ -1,29 +1,40 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. 
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +48,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +83,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +151,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +167,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -193,20 +205,157 @@ def init_optimizer_state(workload: spec.Workload, """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng + del hyperparameters + + hyperparameters = HPARAMS - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint*0.75, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: 
spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch From 2439135829803cb4130e1b18e6766dc93782a0f7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:22:37 +0000 Subject: [PATCH 105/169] update readme --- reference_algorithms/threshold_submissions/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md index eb8995408..d73706ad7 100644 --- a/reference_algorithms/threshold_submissions/README.md +++ b/reference_algorithms/threshold_submissions/README.md @@ -1,4 +1,5 @@ # Threshold Submissions +TODO: link back to section in rules. ## Externally Tuned Ruleset From cc0c6ff817e29e5bf40f2ae7708463357846ed52 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:23:05 +0000 Subject: [PATCH 106/169] tuning search space --- .../external_tuning/tuning_search_space.json | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json diff --git a/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json b/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + From 0338f8fcfbce807aac5bde4fbf278f13a9eeb1b1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:26:37 +0000 Subject: [PATCH 107/169] add reference algorithms --- .../threshold_baselines/README.md | 81 +++++++++++++++++++ .../external_tuning/jax_nadamw_full_budget.py | 0 .../jax_nadamw_target_setting.py | 0 .../pytorch_nadamw_full_budget.py | 0 .../pytorch_nadamw_target_setting.py | 0 .../external_tuning/tuning_search_space.json | 0 .../self_tuning/jax_nadamw_full_budget.py | 0 .../self_tuning/jax_nadamw_target_setting.py | 0 .../self_tuning/pytorch_nadamw_full_budget.py | 0 .../pytorch_nadamw_target_setting.py | 0 .../threshold_submissions/README.md | 81 ------------------- 11 files changed, 81 insertions(+), 81 deletions(-) create mode 100644 reference_algorithms/threshold_baselines/README.md rename reference_algorithms/{threshold_submissions => 
threshold_baselines}/external_tuning/jax_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/jax_nadamw_target_setting.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/pytorch_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/pytorch_nadamw_target_setting.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/tuning_search_space.json (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/jax_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/jax_nadamw_target_setting.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/pytorch_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/pytorch_nadamw_target_setting.py (100%) delete mode 100644 reference_algorithms/threshold_submissions/README.md diff --git a/reference_algorithms/threshold_baselines/README.md b/reference_algorithms/threshold_baselines/README.md new file mode 100644 index 000000000..fa0971997 --- /dev/null +++ b/reference_algorithms/threshold_baselines/README.md @@ -0,0 +1,81 @@ +# Threshold Baselines +TODO: link back to section in rules. + +## Externally Tuned Ruleset + +### JAX + +The threshold submissions for jax are: +- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` + +Example command: + +```bash +python3 submission_runner.py \ + --framework=jax \ + --data_dir= \ + --experiment_dir= \ + --experiment_name= \ + --workload= \ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json +``` + +### PyTorch + +The threshold submissions for PyTorch are +- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` + + +Example command: + +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ + --framework=pytorch \ + --data_dir= \ + --experiment_dir= \ + --experiment_name=t \ + --workload=\ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json +``` + +## Self-tuning Ruleset + +### JAX + +The threshold submissions for jax are +- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` + +Example command: +```bash +python3 submission_runner.py \ + --framework=jax \ + --data_dir= \ + --experiment_dir= \ + --experiment_name= \ + --workload= \ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_ruleset=self +``` + +### PyTorch + +The threshold submissions for PyTorch are +- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- 
`feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` + +Example command: +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ + --framework=pytorch \ + --data_dir= \ + --experiment_dir= \ + --experiment_name=t \ + --workload=\ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --tuning_ruleset=self +``` \ No newline at end of file diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json b/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json rename to reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from 
reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md deleted file mode 100644 index d73706ad7..000000000 --- a/reference_algorithms/threshold_submissions/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# Threshold Submissions -TODO: link back to section in rules. - -## Externally Tuned Ruleset - -### JAX - -The threshold submissions for jax are: -- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py` - -Example command: - -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json -``` - -### PyTorch - -The threshold submissions for PyTorch are -- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py` - - -Example command: - -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json -``` - -## Self-tuning Ruleset - -### JAX - -The threshold submissions for jax are -- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py` - -Example command: -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py \ - --tuning_ruleset=self -``` - -### PyTorch - -The threshold submissions for PyTorch are -- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py` - -Example command: -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_ruleset=self -``` \ No newline at end of file From 
5f834047e249b0419c0d2082b6737a1e3cd5c4a3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:40:25 +0000 Subject: [PATCH 108/169] formatting --- .../self_tuning/jax_nadamw_full_budget.py | 19 +++++++++-------- .../self_tuning/jax_nadamw_target_setting.py | 21 ++++++++++--------- .../self_tuning/pytorch_nadamw_full_budget.py | 19 +++++++++-------- .../pytorch_nadamw_target_setting.py | 21 ++++++++++--------- 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py index b35750086..c54202e56 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py @@ -27,13 +27,14 @@ _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload, del rng del hyperparameters - hyperparameters=HPARAMS - + hyperparameters = HPARAMS + def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py index 190720213..dd42743e2 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py @@ -27,13 +27,14 @@ _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload, del rng del hyperparameters - hyperparameters=HPARAMS - + hyperparameters = HPARAMS + def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) @@ -192,7 +193,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return schedule_fn # Create optimizer + LR schedule. 
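  # For intuition, the schedule built by `jax_cosine_warmup` ramps the learning
  # rate linearly from 0 to `hyperparameters.learning_rate` over the warmup
  # steps and then follows a cosine decay back towards 0. As a rough sketch,
  # assuming (for illustration only) workload.step_hint * 0.75 == 7500 and
  # warmup_factor == 0.02, i.e. 150 warmup steps:
  #   schedule_fn(0)    -> 0.0
  #   schedule_fn(150)  -> hyperparameters.learning_rate
  #   schedule_fn(7500) -> ~0.0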
- lr_schedule_fn = jax_cosine_warmup(workload.step_hint*0.75, hyperparameters) + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, b1=1.0 - hyperparameters.one_minus_beta1, diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py index a1cf612f2..57da48167 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -16,14 +16,15 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): @@ -253,7 +254,7 @@ def update_params(workload: spec.Workload, del hyperparameters hyperparameters = HPARAMS - + current_model = current_param_container current_model.train() optimizer_state['optimizer'].zero_grad() diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1209abadc..ef6e84c94 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -16,14 +16,15 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. 
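# A rough sketch of the update rule implemented by NAdamW below, following the
# simplified NAdam formulation in Table 1 of https://arxiv.org/abs/1910.05446;
# the exact bias-correction details live in the `nadamw` helper further down.
# The only change from AdamW is the look-ahead numerator:
#   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
#   numerator = beta1 * m_t + (1 - beta1) * g_t   # AdamW would use m_t here.
#   theta_t = (1 - lr * weight_decay) * theta_{t-1}
#             - lr * bias_corrected(numerator) / (sqrt(bias_corrected(v_t)) + eps)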
class NAdamW(torch.optim.Optimizer): @@ -230,7 +231,7 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint*0.75, hyperparameters, optimizer_state['optimizer']) + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) return optimizer_state @@ -253,7 +254,7 @@ def update_params(workload: spec.Workload, del hyperparameters hyperparameters = HPARAMS - + current_model = current_param_container current_model.train() optimizer_state['optimizer'].zero_grad() From 811b7c4dab634424fdfdebd28db48850f7eab1ec Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:42:42 +0000 Subject: [PATCH 109/169] baselines --- reference_algorithms/threshold_baselines/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/reference_algorithms/threshold_baselines/README.md b/reference_algorithms/threshold_baselines/README.md index fa0971997..09eed8f41 100644 --- a/reference_algorithms/threshold_baselines/README.md +++ b/reference_algorithms/threshold_baselines/README.md @@ -5,7 +5,7 @@ TODO: link back to section in rules. ### JAX -The threshold submissions for jax are: +The threshold baseline submissions for jax are: - `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` @@ -24,7 +24,7 @@ python3 submission_runner.py \ ### PyTorch -The threshold submissions for PyTorch are +The threshold baseline submissionss for PyTorch are: - `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` @@ -46,7 +46,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc ### JAX -The threshold submissions for jax are +The threshold baseline submissionss for jax are: - `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` @@ -64,7 +64,7 @@ python3 submission_runner.py \ ### PyTorch -The threshold submissions for PyTorch are +The threshold baseline submissionss for PyTorch are: - `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` From 30a0654963748b634f9f3a46b5559bfb3536ae77 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 23:57:20 +0000 Subject: [PATCH 110/169] rename prize qualification baselines --- .../threshold_baselines/README.md | 81 ---- .../external_tuning/jax_nadamw_full_budget.py | 345 ----------------- .../jax_nadamw_target_setting.py | 345 ----------------- .../pytorch_nadamw_full_budget.py | 347 ----------------- .../pytorch_nadamw_target_setting.py | 347 ----------------- .../external_tuning/tuning_search_space.json | 50 --- .../self_tuning/jax_nadamw_full_budget.py | 360 ----------------- .../self_tuning/jax_nadamw_target_setting.py | 360 ----------------- .../self_tuning/pytorch_nadamw_full_budget.py | 362 ------------------ .../pytorch_nadamw_target_setting.py | 362 ------------------ 10 files changed, 2959 deletions(-) delete mode 100644 reference_algorithms/threshold_baselines/README.md delete mode 100644 
reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_baselines/README.md b/reference_algorithms/threshold_baselines/README.md deleted file mode 100644 index 09eed8f41..000000000 --- a/reference_algorithms/threshold_baselines/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# Threshold Baselines -TODO: link back to section in rules. - -## Externally Tuned Ruleset - -### JAX - -The threshold baseline submissions for jax are: -- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` - -Example command: - -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json -``` - -### PyTorch - -The threshold baseline submissionss for PyTorch are: -- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` - - -Example command: - -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json -``` - -## Self-tuning Ruleset - -### JAX - -The threshold baseline submissionss for jax are: -- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` - -Example command: -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ - --tuning_ruleset=self -``` - -### PyTorch - -The threshold baseline submissionss for PyTorch are: -- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` - -Example command: -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 
submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_ruleset=self -``` \ No newline at end of file diff --git a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py deleted file mode 100644 index 099613fcf..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). - - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. 
- """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. - """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py deleted file mode 100644 index ef0c11c0d..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). - - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. 
- """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. - """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py deleted file mode 100644 index 01cffc52e..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ /dev/null @@ -1,347 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. - """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. 
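# In-place Nesterov look-ahead: at this point `exp_avg` holds
# m_t = beta1 * m_{t-1} + (1 - beta1) * g_t; applying the same update once
# more (next line) turns it into the look-ahead
# m_hat = beta1 * m_t + (1 - beta1) * g_t, which is what the parameter step
# below divides by the preconditioner. The trailing
# `exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)` inverts this exactly,
# restoring m_t for the next iteration without allocating a gradient copy.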
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. 
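# The nested `torch.norm(torch.stack([...]), 2)` below is the global L2 norm
# of all gradients: the 2-norm of the vector of per-parameter 2-norms equals
# the 2-norm of every gradient entry concatenated into one vector.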
- if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py deleted file mode 100644 index 530dd3acf..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ /dev/null @@ -1,347 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. 
- """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 
'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json deleted file mode 100644 index 65562905a..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json +++ /dev/null @@ -1,50 +0,0 @@ -[ - { - "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.2, - "learning_rate": 0.0008445074561975979, - "one_minus_beta1": 0.11042418465, - "beta2": 0.9978504782314613, - "weight_decay": 0.08135402759553023, - "warmup_factor": 0.05 - }, - { - "dropout_rate": 0.0, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "learning_rate": 0.004958460849689891, - "one_minus_beta1": 0.13625575743, - "beta2": 0.6291854735396584, - "weight_decay": 0.1147386261512052, - "warmup_factor": 0.02 - }, - { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } -] - - - - - - diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py deleted file mode 100644 index c54202e56..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py +++ /dev/null @@ -1,360 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
- - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. - """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. 
- """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. 
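# A minimal sketch (illustrative numbers, not taken from this file) of the
# schedule that `jax_cosine_warmup` builds below: a linear ramp from 0 to the
# base learning rate over `warmup_factor * step_hint` steps, then a cosine
# decay towards 0 over the remaining steps.
#
#   import optax
#   warmup_steps = int(0.02 * 1000)  # hypothetical warmup_factor * step_hint
#   schedule = optax.join_schedules(
#       schedules=[
#           optax.linear_schedule(0.0, 1.75e-3, transition_steps=warmup_steps),
#           optax.cosine_decay_schedule(1.75e-3, decay_steps=1000 - warmup_steps),
#       ],
#       boundaries=[warmup_steps])
#   # schedule(0) == 0.0; schedule(warmup_steps) == 1.75e-3; schedule(1000) ~= 0.0
#
# Also note that HPARAMS is a plain dict, so the attribute-style accesses
# below (e.g. `hyperparameters.warmup_factor`) assume it is wrapped in an
# attribute-style container such as a namedtuple; a raw dict would need
# `hyperparameters['warmup_factor']` instead.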
- lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py deleted file mode 100644 index dd42743e2..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py +++ /dev/null @@ -1,360 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
- - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. - """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. 
- """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. 
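# Unlike the full-budget variant above, this target-setting submission passes
# `workload.step_hint * 0.75` to `jax_cosine_warmup`, so both the warmup and
# the cosine decay finish by roughly 75% of the step hint.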
- lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py deleted file mode 100644 index 57da48167..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. 
- """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing 
if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py deleted file mode 100644 index ef6e84c94..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. 
- - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. - """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. 
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
- summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch From b24e7e7284186cd852f4bb5fdb032148ca69659a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 23:57:37 +0000 Subject: [PATCH 111/169] rename --- .../prize_qualification_baselines/README.md | 83 ++++ .../external_tuning/jax_nadamw_full_budget.py | 345 +++++++++++++++++ .../jax_nadamw_target_setting.py | 345 +++++++++++++++++ .../pytorch_nadamw_full_budget.py | 347 +++++++++++++++++ .../pytorch_nadamw_target_setting.py | 347 +++++++++++++++++ .../external_tuning/tuning_search_space.json | 50 +++ .../self_tuning/jax_nadamw_full_budget.py | 360 +++++++++++++++++ .../self_tuning/jax_nadamw_target_setting.py | 360 +++++++++++++++++ .../self_tuning/pytorch_nadamw_full_budget.py | 362 ++++++++++++++++++ .../pytorch_nadamw_target_setting.py | 362 ++++++++++++++++++ 10 files changed, 2961 insertions(+) create mode 100644 reference_algorithms/prize_qualification_baselines/README.md create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md new file mode 100644 index 000000000..614f87b32 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/README.md @@ -0,0 +1,83 @@ +# Prize Qualification Baselines +This directory contains the baseine(s) that submissions that must beat to qualify for prizes. + +TODO: link back to section in rules. 
+
+## Externally Tuned Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py`
+
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+## Self-tuning Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
\ No newline at end of file
diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
new file mode 100644
index 000000000..099613fcf
--- /dev/null
+++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -0,0 +1,345 @@
+"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
+
+import functools
+
+# isort: off
+# We have to turn off isort here to resolve a conflict between isort and yapf.
+from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. 
+ eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
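+  # Full-budget variant: the warmup+cosine schedule below is built from the
+  # full workload.step_hint; the companion target-setting submission passes
+  # 0.75 * step_hint instead.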
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. 
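+  # loss and grad_norm come back from pmap with a leading device axis; the
+  # psum inside pmapped_train_step makes every replica identical, so indexing
+  # [0] below reads the global value.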
+ if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..ef0c11c0d --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. 
+ eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
+ mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
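+  # lax.psum sums the per-device loss, valid-example count and gradients over
+  # the 'batch' pmap axis, so every replica can normalize by the global number
+  # of valid examples.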
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..01cffc52e --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
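+    # Concretely, exp_avg temporarily becomes the Nesterov look-ahead moment
+    # beta1 * m_t + (1 - beta1) * g_t; the sub_/div_ pair at the end of this
+    # loop body restores m_t in place.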
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
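+  # The logged grad_norm is the global L2 norm over all parameters that
+  # received gradients, computed from the (possibly clipped) gradients that
+  # are still stored on the parameters after the optimizer step.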
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..530dd3acf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
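+
+  Sketch of one update for a parameter p with gradient g at step t, mirroring
+  the in-place ops below:
+
+    p <- p * (1 - lr * weight_decay)
+    m <- beta1 * m + (1 - beta1) * g
+    v <- beta2 * v + (1 - beta2) * g**2
+    m_nesterov <- beta1 * m + (1 - beta1) * g
+    p <- p - (lr / (1 - beta1**t)) * m_nesterov
+             / (sqrt(v) / sqrt(1 - beta2**t) + eps)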
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 
'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..c54202e56 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
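+
+  Example (first step, debias=True, constant gradient g): mu_hat debiases to
+    (1 + b1) * g and nu_hat to g**2, so the returned update is
+    (1 + b1) * g / (|g| + eps), i.e. roughly (1 + b1) * sign(g).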
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
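+  # Example: with warmup_factor=0.02 and a hypothetical step_hint of 10_000,
+  # the learning rate warms up linearly over the first 200 steps and then
+  # follows a cosine decay over the remaining 9_800 steps.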
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..dd42743e2 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
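+  # Target-setting variant: the schedule spans only 75% of the workload step
+  # hint, so both the warmup and the cosine decay finish earlier than in the
+  # full-budget submission.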
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..57da48167 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
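+
+  All updates are applied in place to params, exp_avgs, exp_avg_sqs and
+  state_steps; callers should not rely on their pre-update values after this
+  function returns.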
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing 
if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..ef6e84c94 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
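+    # After this extra in-place update, exp_avg temporarily holds the
+    # Nesterov-style lookahead moment beta1 * m_t + (1 - beta1) * g_t; it is
+    # used in the parameter update below and then restored to m_t by the
+    # trailing sub_/div_ pair at the end of the loop.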
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
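+    # dist_nn.all_reduce is the autograd-aware all-reduce, so summing the
+    # per-device loss and valid-example counts before dividing gives every
+    # replica the same global mean loss (and matching gradients), independent
+    # of how the batch is split across devices.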
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch From a2b64e9d387fd4463ddd1948bff90c46c7f96d88 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 02:52:43 +0000 Subject: [PATCH 112/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index d3d908b5c..1d76c436c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -98,6 +98,11 @@ def init_model_fn( input_size = self.num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) + tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), + console_kwargs={'force_terminal': False, + 'force_jupyter': False, + 'width': 240},) + print(tabulate_fn(fake_inputs, train=False)) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, From ce7fcd95cac8e07800d78c87ec1c673c13229e8b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 02:53:31 +0000 Subject: [PATCH 113/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 1d76c436c..657772360 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -102,7 +102,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240},) - print(tabulate_fn(fake_inputs, train=False)) + print(tabulate_fn(fake_inputs, train=False)) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, From 26f78642379d12b5391bc048a91379cb8df14bde Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 03:06:29 +0000 Subject: [PATCH 114/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 657772360..9e8b9736e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -98,16 +98,16 @@ def init_model_fn( input_size = self.num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) - tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), - console_kwargs={'force_terminal': False, - 'force_jupyter': False, - 'width': 240},) - print(tabulate_fn(fake_inputs, train=False)) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape, jnp.float32)) fake_inputs = jnp.ones(input_shape, jnp.float32) + tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), + console_kwargs={'force_terminal': False, + 
'force_jupyter': False, + 'width': 240},) + print(tabulate_fn(fake_inputs, train=False)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) From 8243820142295a3e9b0f40bfc0ceea6d30a75d8a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 03:11:19 +0000 Subject: [PATCH 115/169] debug --- tests/modeldiffs/criteo1tb_layernorm/compare.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index d46d57d15..fa34b9f63 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -18,6 +18,8 @@ def key_transform(k): new_key = [] s_count = None + print('key') + print(key) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) From c44b5bdc44c741e1e629fef954c56d71c622c5ca Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 03:12:22 +0000 Subject: [PATCH 116/169] debug --- tests/modeldiffs/criteo1tb_layernorm/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index fa34b9f63..fc306e625 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -19,7 +19,7 @@ def key_transform(k): new_key = [] s_count = None print('key') - print(key) + print(k) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) From 6913e9225989a97d54d105087cd07742e25901e5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 03:15:29 +0000 Subject: [PATCH 117/169] fix --- tests/modeldiffs/criteo1tb_layernorm/compare.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index fc306e625..130d0d84b 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -30,6 +30,9 @@ def key_transform(k): i = i.replace('Linear', 'Dense') name, count = i.split('_') i = name + '_' + str(s_count * 3 + int(count)) + if 'LayerNorm' in i: + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) elif 'weight' in i: i = i.replace('weight', 'kernel') From 6fc183002b77adb11c78e8fd7650e380e90e355b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 03:49:56 +0000 Subject: [PATCH 118/169] debugigng --- tests/modeldiffs/criteo1tb/compare.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 8c5881a8e..cb6806596 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -18,6 +18,8 @@ def key_transform(k): new_key = [] s_count = None + print('key') + print(k) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) From a140fc059424ecf8fef34a1ffc7d2f681cf8326b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 04:02:59 +0000 Subject: [PATCH 119/169] fix --- tests/modeldiffs/criteo1tb_layernorm/compare.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 130d0d84b..e56d65126 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ 
b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -32,7 +32,11 @@ def key_transform(k): i = name + '_' + str(s_count * 3 + int(count)) if 'LayerNorm' in i: name, count = i.split('_') - i = name + '_' + str(s_count * 3 + int(count)) + # There is a layernorm on embedding between bottom and top MLP + if s_count is not None: + i = name + '_' + str(s_count * 4 + int(count)) + else: + i = name + '_' + str(3) elif 'weight' in i: i = i.replace('weight', 'kernel') From 9550c15f8bb67e4cdc82ed41d069087167cc13e4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 04:06:27 +0000 Subject: [PATCH 120/169] debug --- tests/modeldiffs/criteo1tb_layernorm/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index e56d65126..a8ee234aa 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -39,8 +39,8 @@ def key_transform(k): i = name + '_' + str(3) elif 'weight' in i: i = i.replace('weight', 'kernel') - new_key.append(i) + print(new_key) return tuple(new_key) From bcf68ed1d4f6c57f94f22d84b9d72b4c0fd80b4e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 04:09:37 +0000 Subject: [PATCH 121/169] fix test --- tests/modeldiffs/criteo1tb_layernorm/compare.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index a8ee234aa..d786bb11c 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -18,6 +18,7 @@ def key_transform(k): new_key = [] s_count = None + layer_norm = False print('key') print(k) for i in k: @@ -31,6 +32,7 @@ def key_transform(k): name, count = i.split('_') i = name + '_' + str(s_count * 3 + int(count)) if 'LayerNorm' in i: + layer_norm = True name, count = i.split('_') # There is a layernorm on embedding between bottom and top MLP if s_count is not None: @@ -38,7 +40,10 @@ def key_transform(k): else: i = name + '_' + str(3) elif 'weight' in i: - i = i.replace('weight', 'kernel') + if layer_norm: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') new_key.append(i) print(new_key) return tuple(new_key) From 85be241839b4917c93b4fb50a797d9f05f1703a1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 04:12:41 +0000 Subject: [PATCH 122/169] remove debugging statements --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 1 - tests/modeldiffs/criteo1tb_layernorm/compare.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 8ef0cdb2e..eff71d86f 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -214,7 +214,6 @@ def __init__(self, input_dim = dense_dim self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): - print(module) if isinstance(module, nn.Linear): limit = math.sqrt(6. 
/ (module.in_features + module.out_features)) nn.init.uniform_(module.weight.data, -limit, limit) diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index d786bb11c..9b36ef1b6 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -19,8 +19,6 @@ def key_transform(k): new_key = [] s_count = None layer_norm = False - print('key') - print(k) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) @@ -45,7 +43,6 @@ def key_transform(k): else: i = i.replace('weight', 'kernel') new_key.append(i) - print(new_key) return tuple(new_key) From 987686c85d52d310c44460349f774d80078fe949 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 04:21:09 +0000 Subject: [PATCH 123/169] add jax model summary helper fn --- algorithmic_efficiency/logger_utils.py | 4 +++- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 ----- tests/utils.py | 11 +++++++++++ 3 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 tests/utils.py diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index b7bde226a..9a881523a 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -2,6 +2,8 @@ import collections import json +import flax.linen as nn +import jax import logging import os.path import platform @@ -341,4 +343,4 @@ def set_up_loggers(train_dir: str, events_dir=train_dir, configs=configs, hyperparameters=hyperparameters) - return metrics_logger + return metrics_logger \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 9e8b9736e..d3d908b5c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -103,11 +103,6 @@ def init_model_fn( {'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape, jnp.float32)) fake_inputs = jnp.ones(input_shape, jnp.float32) - tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), - console_kwargs={'force_terminal': False, - 'force_jupyter': False, - 'width': 240},) - print(tabulate_fn(fake_inputs, train=False)) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 000000000..27dac8dee --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,11 @@ +import flax.linen as nn +import jax + + +def print_jax_model_summary(model, fake_inputs): + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate(model, jax.random.PRNGKey(0), + console_kwargs={'force_terminal': False, + 'force_jupyter': False, + 'width': 240},) + print(tabulate_fn(fake_inputs, train=False)) \ No newline at end of file From c6e26dafc4501d51a5fac3058c878efac278cbe8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 04:28:07 +0000 Subject: [PATCH 124/169] formatting --- .../criteo1tb/criteo1tb_jax/workload.py | 4 ++-- .../criteo1tb/criteo1tb_pytorch/models.py | 11 +++++++---- algorithmic_efficiency/workloads/workloads.py | 17 +++++++++++++---- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index d3d908b5c..bc68cbfd3 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -76,7 +76,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None, tabulate: Optional[bool] = False, - ) -> spec.ModelInitState: + ) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate if self.use_resnet: @@ -185,7 +185,7 @@ class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) embed_dim: int = 256 - + @property def use_resnet(self) -> bool: """Whether or not to use residual connections in the model.""" diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index eff71d86f..f5f1d32fa 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -7,7 +7,8 @@ class DenseBlock(nn.Module): - """Dense block with optional residual connection.""""" + """Dense block with optional residual connection.""" "" + def __init__(self, module, resnet=False): super().__init__() self.module = module @@ -18,7 +19,8 @@ def forward(self, x): return self.module(x) + x else: return self.module(x) - + + class DotInteract(nn.Module): """Performs feature interaction operation between dense or sparse features.""" @@ -48,6 +50,7 @@ class DLRMResNet(nn.Module): mlp_top_dims: dimensions of dense layers of the top mlp. embed_dim: embedding dimension. """ + def __init__(self, vocab_size, num_dense_features=13, @@ -86,7 +89,7 @@ def __init__(self, block = [] block.append(nn.Linear(input_dim, dense_dim)) block.append(nn.ReLU(inplace=True)) - block = nn.Sequential(*block) + block = nn.Sequential(*block) if layer_idx > 0: block = DenseBlock(block, resnet=True) else: @@ -104,7 +107,7 @@ def __init__(self, math.sqrt(1. / module.out_features)) self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - + # TODO: Write down the formula here instead of the constant. 
fan_in = 634 num_layers_top = len(self.mlp_top_dims) diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index c809e8234..410fdb1f3 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -57,14 +57,23 @@ 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, } -BASE_WORKLOADS = ['criteo1tb', 'fastmri', ' imagenet_resnet', 'imagenet_vit', - 'librispeech_conformer', 'librispeech_deepspeech', - 'ogbg', 'wmt'] +BASE_WORKLOADS = [ + 'criteo1tb', + 'fastmri', + ' imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'ogbg', + 'wmt' +] + def get_base_workload_name(workload_name): for base_workload_name in BASE_WORKLOADS: if base_workload_name in workload_name: - return base_workload_name + return base_workload_name + def convert_filepath_to_module(path: str): base, extension = os.path.splitext(path) From 90e8c8093be97657f298e074b0f085dee4f5eb20 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 17:24:18 +0000 Subject: [PATCH 125/169] update jax criteo workload --- .../criteo1tb/criteo1tb_jax/models.py | 74 +++++++++---------- .../criteo1tb/criteo1tb_jax/workload.py | 2 + .../workloads}/utils.py | 0 3 files changed, 39 insertions(+), 37 deletions(-) rename {tests => algorithmic_efficiency/workloads}/utils.py (100%) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 04a7c485d..8c6b21cb7 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -22,7 +22,7 @@ class DLRMResNet(nn.Module): num_dense_features: int = 13 mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) - embed_dim: int = 256 + embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. @@ -31,80 +31,80 @@ def __call__(self, x, train): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - # Bottom MLP + activation_fn = model_utils.ACTIVATIONS[self.activation_function] + + # bottom mlp mlp_bottom_dims = self.mlp_bottom_dims bot_mlp_input = nn.Dense( mlp_bottom_dims[0], kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_bottom_dims[0])), - )( - bot_mlp_input) - bot_mlp_input = nn.relu(bot_mlp_input) + stddev=1.0 / mlp_bottom_dims[0]**0.5), + )(bot_mlp_input) + bot_mlp_input = activation_fn(bot_mlp_input) for dense_dim in mlp_bottom_dims[1:]: x = nn.Dense( dense_dim, kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )( - bot_mlp_input) - bot_mlp_input += nn.relu(x) - - bot_mlp_output = bot_mlp_input - batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) + bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), + )(bot_mlp_input) + bot_mlp_input += activation_fn(x) + + base_init_fn = jnn.initializers.uniform(scale=1.0) + if self.embedding_init_multiplier is None: + embedding_init_multiplier = 1 / self.vocab_size**0.5 + else: + embedding_init_multiplier = self.embedding_init_multiplier # Embedding table init and lookup for a single unified table. 
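+    # All 26 categorical features share this single table: their IDs are
+    # folded into the table with a modulo, so features may collide on rows,
+    # but the parameter count stays fixed at vocab_size * embed_dim.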
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / - jnp.sqrt(self.vocab_size)) + return base_init_fn(key, shape, dtype) * embedding_init_multiplier - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) + embedding_table = self.param( + 'embedding_table', + scaled_init, + [self.vocab_size, self.embed_dim]) - idx_lookup = jnp.reshape(idx_lookup, [-1]) embed_features = embedding_table[idx_lookup] - embed_features = jnp.reshape(embed_features, - [batch_size, -1, self.embed_dim]) - feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) - dot_interact_output = dot_interact(concat_features=feature_stack) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) + batch_size = bot_mlp_input.shape[0] + embed_features = jnp.reshape( + embed_features, (batch_size, 26 * self.embed_dim)) + top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) top_mlp_input = nn.Dense( mlp_top_dims[0], kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), + stddev=(2.0 / (mlp_input_dim + mlp_top_dims[0]))**0.5), bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( + stddev=(1.0 / mlp_top_dims[0])**0.5))( top_mlp_input) - top_mlp_input = nn.relu(top_mlp_input) + top_mlp_input = activation_fn(top_mlp_input) for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: fan_in = mlp_top_dims[layer_idx - 1] x = nn.Dense( fan_out, kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + stddev=(2.0 / (fan_in + fan_out))**0.5), bias_init=jnn.initializers.normal( - stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( + stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))( top_mlp_input) - x = nn.relu(x) + x = activation_fn(x) if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: - x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + x = nn.Dropout( + rate=self.dropout_rate, deterministic=not train)(x) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. 
logits = nn.Dense( 1, kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( - top_mlp_input) + stddev=(2.0 / (mlp_top_dims[-2] + 1))**0.5), + bias_init=jnn.initializers.normal( + stddev=1.0))(top_mlp_input) return logits diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index bc68cbfd3..c1d09d079 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -14,6 +14,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models from algorithmic_efficiency.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload +from algorithmic_efficiency.workloads import utils class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload): @@ -103,6 +104,7 @@ def init_model_fn( {'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape, jnp.float32)) fake_inputs = jnp.ones(input_shape, jnp.float32) + utils.print_jax_model_summary(self._model, fake_inputs) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/tests/utils.py b/algorithmic_efficiency/workloads/utils.py similarity index 100% rename from tests/utils.py rename to algorithmic_efficiency/workloads/utils.py From e44020868183e143b692e2fa02b1d88bc818b321 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 22 Nov 2023 17:31:29 +0000 Subject: [PATCH 126/169] criteo workload variant --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 8c6b21cb7..afe853504 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -31,8 +31,6 @@ def __call__(self, x, train): bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - activation_fn = model_utils.ACTIVATIONS[self.activation_function] - # bottom mlp mlp_bottom_dims = self.mlp_bottom_dims @@ -42,7 +40,7 @@ def __call__(self, x, train): bias_init=jnn.initializers.normal( stddev=1.0 / mlp_bottom_dims[0]**0.5), )(bot_mlp_input) - bot_mlp_input = activation_fn(bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) for dense_dim in mlp_bottom_dims[1:]: x = nn.Dense( @@ -50,7 +48,7 @@ def __call__(self, x, train): kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), )(bot_mlp_input) - bot_mlp_input += activation_fn(x) + bot_mlp_input += nn.relu(x) base_init_fn = jnn.initializers.uniform(scale=1.0) if self.embedding_init_multiplier is None: @@ -82,7 +80,7 @@ def scaled_init(key, shape, dtype=jnp.float_): bias_init=jnn.initializers.normal( stddev=(1.0 / mlp_top_dims[0])**0.5))( top_mlp_input) - top_mlp_input = activation_fn(top_mlp_input) + top_mlp_input = nn.relu(top_mlp_input) for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: fan_in = mlp_top_dims[layer_idx - 1] x = nn.Dense( @@ -92,7 +90,7 @@ def scaled_init(key, shape, dtype=jnp.float_): bias_init=jnn.initializers.normal( stddev=(1.0 / 
mlp_top_dims[layer_idx])**0.5))( top_mlp_input) - x = activation_fn(x) + x = nn.relu(x) if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: x = nn.Dropout( rate=self.dropout_rate, deterministic=not train)(x) From 626420c3312810f8a25ae8abb577a1c7ecfa20a7 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 17:37:45 +0000 Subject: [PATCH 127/169] fixes --- .../criteo1tb/criteo1tb_jax/models.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index afe853504..bb86328fe 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -51,14 +51,10 @@ def __call__(self, x, train): bot_mlp_input += nn.relu(x) base_init_fn = jnn.initializers.uniform(scale=1.0) - if self.embedding_init_multiplier is None: - embedding_init_multiplier = 1 / self.vocab_size**0.5 - else: - embedding_init_multiplier = self.embedding_init_multiplier # Embedding table init and lookup for a single unified table. idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size def scaled_init(key, shape, dtype=jnp.float_): - return base_init_fn(key, shape, dtype) * embedding_init_multiplier + return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) embedding_table = self.param( 'embedding_table', @@ -76,9 +72,9 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.Dense( mlp_top_dims[0], kernel_init=jnn.initializers.normal( - stddev=(2.0 / (mlp_input_dim + mlp_top_dims[0]))**0.5), + stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), bias_init=jnn.initializers.normal( - stddev=(1.0 / mlp_top_dims[0])**0.5))( + stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( top_mlp_input) top_mlp_input = nn.relu(top_mlp_input) for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: @@ -86,9 +82,9 @@ def scaled_init(key, shape, dtype=jnp.float_): x = nn.Dense( fan_out, kernel_init=jnn.initializers.normal( - stddev=(2.0 / (fan_in + fan_out))**0.5), + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), bias_init=jnn.initializers.normal( - stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))( + stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( top_mlp_input) x = nn.relu(x) if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: @@ -100,9 +96,9 @@ def scaled_init(key, shape, dtype=jnp.float_): logits = nn.Dense( 1, kernel_init=jnn.initializers.normal( - stddev=(2.0 / (mlp_top_dims[-2] + 1))**0.5), + stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), bias_init=jnn.initializers.normal( - stddev=1.0))(top_mlp_input) + stddev=jnp.sqrt(1.0)))(top_mlp_input) return logits From e070fe32d1a50856a17122cc04f077846d169729 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 17:48:28 +0000 Subject: [PATCH 128/169] debug --- tests/modeldiffs/criteo1tb_resnet/compare.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 37d63a519..e44c199c4 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -6,6 +6,7 @@ import jax import numpy as np import torch +import jax.numpy as jnp from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ @@ -68,7 +69,14 @@ def sd_transform(sd): 'targets': torch.randint(low=0, high=1, 
size=(2,)), 'weights': torch.ones(2), } + + init_fake_batch_size = 2 + num_categorical_features = 26 + input_size = 13 + num_categorical_features + input_shape = (init_fake_batch_size, input_size) + fake_inputs = jnp.ones(input_shape, jnp.float32) jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch['inputs'] = fake_inputs # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( From 412e4fdb7fd81d30cb682887e18c8fcd56de26d3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 17:59:11 +0000 Subject: [PATCH 129/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 +++-- tests/modeldiffs/criteo1tb_layernorm/compare.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index c1d09d079..bcd1b9f41 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -96,9 +96,10 @@ def init_model_fn( params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 - input_size = self.num_dense_features + num_categorical_features + num_dense_features = 13 + input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) - + print(input_shape) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 9b36ef1b6..1f16e8b97 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -81,7 +81,7 @@ def sd_transform(sd): jax_model_kwargs = dict( augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, + # mode=spec.ForwardPassMode.EVAL, rng=jax.random.PRNGKey(0), update_batch_norm=False) From c9823265ccb98cb2d9e183459fa8ee9ee5820382 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:02:51 +0000 Subject: [PATCH 130/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index bcd1b9f41..f722cbde6 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -99,6 +99,7 @@ def init_model_fn( num_dense_features = 13 input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) + print('Input Shape') print(input_shape) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( From f960a0abe79b0faab9280518e2737da97d144099 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:04:40 +0000 Subject: [PATCH 131/169] debugging --- .../criteo1tb/criteo1tb_jax/models.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index bb86328fe..dfcef083d 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -102,32 +102,32 @@ def scaled_init(key, shape, dtype=jnp.float_): return logits -def dot_interact(concat_features): - """Performs feature interaction operation between dense or sparse features. - Input tensors represent dense or sparse features. - Pre-condition: The tensors have been stacked along dimension 1. - Args: - concat_features: Array of features with shape [B, n_features, feature_dim]. - Returns: - activations: Array representing interacted features. - """ - batch_size = concat_features.shape[0] - - # Interact features, select upper or lower-triangular portion, and reshape. - xactions = jnp.matmul(concat_features, - jnp.transpose(concat_features, [0, 2, 1])) - feature_dim = xactions.shape[-1] - - indices = jnp.array(jnp.triu_indices(feature_dim)) - num_elems = indices.shape[1] - indices = jnp.tile(indices, [1, batch_size]) - indices0 = jnp.reshape( - jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), - [1, -1]) - indices = tuple(jnp.concatenate((indices0, indices), 0)) - activations = xactions[indices] - activations = jnp.reshape(activations, [batch_size, -1]) - return activations +# def dot_interact(concat_features): +# """Performs feature interaction operation between dense or sparse features. +# Input tensors represent dense or sparse features. +# Pre-condition: The tensors have been stacked along dimension 1. +# Args: +# concat_features: Array of features with shape [B, n_features, feature_dim]. +# Returns: +# activations: Array representing interacted features. +# """ +# batch_size = concat_features.shape[0] + +# # Interact features, select upper or lower-triangular portion, and reshape. +# xactions = jnp.matmul(concat_features, +# jnp.transpose(concat_features, [0, 2, 1])) +# feature_dim = xactions.shape[-1] + +# indices = jnp.array(jnp.triu_indices(feature_dim)) +# num_elems = indices.shape[1] +# indices = jnp.tile(indices, [1, batch_size]) +# indices0 = jnp.reshape( +# jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), +# [1, -1]) +# indices = tuple(jnp.concatenate((indices0, indices), 0)) +# activations = xactions[indices] +# activations = jnp.reshape(activations, [batch_size, -1]) +# return activations class DlrmSmall(nn.Module): From 122fbf42a46f5b64aac78b7b9854f70514bfb6f2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:06:54 +0000 Subject: [PATCH 132/169] debugging --- .../criteo1tb/criteo1tb_jax/models.py | 162 +++++++++--------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index dfcef083d..d772f7595 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -130,85 +130,85 @@ def scaled_init(key, shape, dtype=jnp.float_): # return activations -class DlrmSmall(nn.Module): - """Define a DLRM-Small model. - - Parameters: - vocab_size: vocab size of embedding table. - num_dense_features: number of dense features as the bottom mlp input. - mlp_bottom_dims: dimensions of dense layers of the bottom mlp. - mlp_top_dims: dimensions of dense layers of the top mlp. - embed_dim: embedding dimension. - """ - - vocab_size: int = 32 * 128 * 1024 # 4_194_304. 
- num_dense_features: int = 13 - mlp_bottom_dims: Sequence[int] = (512, 256, 128) - mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) - embed_dim: int = 128 - dropout_rate: float = 0.0 - use_layer_norm: bool = False - - @nn.compact - def __call__(self, x, train): - bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) - cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - - # Bottom MLP. - for dense_dim in self.mlp_bottom_dims: - bot_mlp_input = nn.Dense( - dense_dim, - kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), - )( - bot_mlp_input) - bot_mlp_input = nn.relu(bot_mlp_input) - if self.use_layer_norm: - bot_mlp_input = nn.LayerNorm()(bot_mlp_input) - bot_mlp_output = bot_mlp_input - batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) - - # Embedding table look-up. - idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - - def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / - jnp.sqrt(self.vocab_size)) - - embedding_table = self.param('embedding_table', - scaled_init, [self.vocab_size, self.embed_dim]) +# class DlrmSmall(nn.Module): +# """Define a DLRM-Small model. + +# Parameters: +# vocab_size: vocab size of embedding table. +# num_dense_features: number of dense features as the bottom mlp input. +# mlp_bottom_dims: dimensions of dense layers of the bottom mlp. +# mlp_top_dims: dimensions of dense layers of the top mlp. +# embed_dim: embedding dimension. +# """ - idx_lookup = jnp.reshape(idx_lookup, [-1]) - embed_features = embedding_table[idx_lookup] - embed_features = jnp.reshape(embed_features, - [batch_size, -1, self.embed_dim]) - if self.use_layer_norm: - embed_features = nn.LayerNorm()(embed_features) - feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) - dot_interact_output = dot_interact(concat_features=feature_stack) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) - mlp_input_dim = top_mlp_input.shape[1] - mlp_top_dims = self.mlp_top_dims - num_layers_top = len(mlp_top_dims) - for layer_idx, fan_out in enumerate(mlp_top_dims): - fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] - top_mlp_input = nn.Dense( - fan_out, - kernel_init=jnn.initializers.normal( - stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), - bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( - top_mlp_input) - if layer_idx < (num_layers_top - 1): - top_mlp_input = nn.relu(top_mlp_input) - if self.use_layer_norm: - top_mlp_input = nn.LayerNorm()(top_mlp_input) - if (self.dropout_rate is not None and self.dropout_rate > 0.0 and - layer_idx == num_layers_top - 2): - top_mlp_input = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) - logits = top_mlp_input - return logits +# vocab_size: int = 32 * 128 * 1024 # 4_194_304. +# num_dense_features: int = 13 +# mlp_bottom_dims: Sequence[int] = (512, 256, 128) +# mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) +# embed_dim: int = 128 +# dropout_rate: float = 0.0 +# use_layer_norm: bool = False + +# @nn.compact +# def __call__(self, x, train): +# bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) +# cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + +# # Bottom MLP. 
+# for dense_dim in self.mlp_bottom_dims: +# bot_mlp_input = nn.Dense( +# dense_dim, +# kernel_init=jnn.initializers.glorot_uniform(), +# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), +# )( +# bot_mlp_input) +# bot_mlp_input = nn.relu(bot_mlp_input) +# if self.use_layer_norm: +# bot_mlp_input = nn.LayerNorm()(bot_mlp_input) +# bot_mlp_output = bot_mlp_input +# batch_size = bot_mlp_output.shape[0] +# feature_stack = jnp.reshape(bot_mlp_output, +# [batch_size, -1, self.embed_dim]) + +# # Embedding table look-up. +# idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + +# def scaled_init(key, shape, dtype=jnp.float_): +# return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / +# jnp.sqrt(self.vocab_size)) + +# embedding_table = self.param('embedding_table', +# scaled_init, [self.vocab_size, self.embed_dim]) + +# idx_lookup = jnp.reshape(idx_lookup, [-1]) +# embed_features = embedding_table[idx_lookup] +# embed_features = jnp.reshape(embed_features, +# [batch_size, -1, self.embed_dim]) +# if self.use_layer_norm: +# embed_features = nn.LayerNorm()(embed_features) +# feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) +# dot_interact_output = dot_interact(concat_features=feature_stack) +# top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], +# axis=-1) +# mlp_input_dim = top_mlp_input.shape[1] +# mlp_top_dims = self.mlp_top_dims +# num_layers_top = len(mlp_top_dims) +# for layer_idx, fan_out in enumerate(mlp_top_dims): +# fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] +# top_mlp_input = nn.Dense( +# fan_out, +# kernel_init=jnn.initializers.normal( +# stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), +# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( +# top_mlp_input) +# if layer_idx < (num_layers_top - 1): +# top_mlp_input = nn.relu(top_mlp_input) +# if self.use_layer_norm: +# top_mlp_input = nn.LayerNorm()(top_mlp_input) +# if (self.dropout_rate is not None and self.dropout_rate > 0.0 and +# layer_idx == num_layers_top - 2): +# top_mlp_input = nn.Dropout( +# rate=self.dropout_rate, deterministic=not train)( +# top_mlp_input) +# logits = top_mlp_input +# return logits From bf12d5b08e3cc44405f506c6a6d36cc4325c86b1 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:14:03 +0000 Subject: [PATCH 133/169] shape debugging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index d772f7595..0ae239999 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -66,6 +66,8 @@ def scaled_init(key, shape, dtype=jnp.float_): embed_features = jnp.reshape( embed_features, (batch_size, 26 * self.embed_dim)) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) + print("top mlp input shape") + print(jnp.shape(top_mlp_input)) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) From a234a8a1fe01a697c405dfa35f02577b822fce64 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:19:03 +0000 Subject: [PATCH 134/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index f722cbde6..0ae5a6b29 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -188,7 +188,7 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) - embed_dim: int = 256 + embed_dim: int = 128 @property def use_resnet(self) -> bool: From 456573a3cd02a4a385515335873565fdcaa85fb3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:20:51 +0000 Subject: [PATCH 135/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 3671572ef..c3c2ed935 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -263,7 +263,7 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) - embed_dim: int = 256 + embed_dim: int = 128 @property def use_resnet(self) -> bool: From 12f5da682122de1497a6a971ca9a752195a63ec0 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:40:49 +0000 Subject: [PATCH 136/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index f5f1d32fa..27ad62ff3 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -30,6 +30,10 @@ def __init__(self, num_sparse_features): num_sparse_features + 1) def forward(self, dense_features, sparse_features): + print("Dense features shape") + print(dense_features.shape) + print(sparse_features.shape) + print(dense_features.unsqueeze(1).shape) combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), dim=1) interactions = torch.bmm(combined_values, From 438deafe12bbbe48ea60b24e5d4e9a39e752fe41 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:54:07 +0000 Subject: [PATCH 137/169] debugging --- .../criteo1tb/criteo1tb_jax/models.py | 223 +++++++++--------- .../criteo1tb/criteo1tb_jax/workload.py | 1 - .../criteo1tb/criteo1tb_pytorch/models.py | 2 +- .../criteo1tb/criteo1tb_pytorch/workload.py | 1 - 4 files changed, 114 insertions(+), 113 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 0ae239999..d4cea60ca 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -53,6 +53,8 @@ def __call__(self, x, train): base_init_fn = jnn.initializers.uniform(scale=1.0) # Embedding table init and lookup for a single unified table. 
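# Note (annotation, not from the original patch): a shape sketch of the unified-table
# lookup here, which is what the debugging prints added in these commits are checking.
# Assumed defaults from this file: vocab_size = 4_194_304, embed_dim = 128,
# 26 categorical features, batch size B.
#   idx_lookup      : (B * 26,) int32 -- all 26 categorical IDs hashed into a single
#                     table by the modulo with vocab_size
#   embedding_table : (4_194_304, 128)
#   embed_features  : (B * 26, 128), reshaped to (B, 26 * embed_dim) = (B, 3328)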
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + print("idx shape") + print(jnp.shape(idx_lookup)) def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) @@ -62,6 +64,7 @@ def scaled_init(key, shape, dtype=jnp.float_): [self.vocab_size, self.embed_dim]) embed_features = embedding_table[idx_lookup] + print(jnp.shape(embed_features)) batch_size = bot_mlp_input.shape[0] embed_features = jnp.reshape( embed_features, (batch_size, 26 * self.embed_dim)) @@ -104,113 +107,113 @@ def scaled_init(key, shape, dtype=jnp.float_): return logits -# def dot_interact(concat_features): -# """Performs feature interaction operation between dense or sparse features. -# Input tensors represent dense or sparse features. -# Pre-condition: The tensors have been stacked along dimension 1. -# Args: -# concat_features: Array of features with shape [B, n_features, feature_dim]. -# Returns: -# activations: Array representing interacted features. -# """ -# batch_size = concat_features.shape[0] - -# # Interact features, select upper or lower-triangular portion, and reshape. -# xactions = jnp.matmul(concat_features, -# jnp.transpose(concat_features, [0, 2, 1])) -# feature_dim = xactions.shape[-1] - -# indices = jnp.array(jnp.triu_indices(feature_dim)) -# num_elems = indices.shape[1] -# indices = jnp.tile(indices, [1, batch_size]) -# indices0 = jnp.reshape( -# jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), -# [1, -1]) -# indices = tuple(jnp.concatenate((indices0, indices), 0)) -# activations = xactions[indices] -# activations = jnp.reshape(activations, [batch_size, -1]) -# return activations - - -# class DlrmSmall(nn.Module): -# """Define a DLRM-Small model. - -# Parameters: -# vocab_size: vocab size of embedding table. -# num_dense_features: number of dense features as the bottom mlp input. -# mlp_bottom_dims: dimensions of dense layers of the bottom mlp. -# mlp_top_dims: dimensions of dense layers of the top mlp. -# embed_dim: embedding dimension. -# """ - -# vocab_size: int = 32 * 128 * 1024 # 4_194_304. -# num_dense_features: int = 13 -# mlp_bottom_dims: Sequence[int] = (512, 256, 128) -# mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) -# embed_dim: int = 128 -# dropout_rate: float = 0.0 -# use_layer_norm: bool = False - -# @nn.compact -# def __call__(self, x, train): -# bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) -# cat_features = jnp.asarray(cat_features, dtype=jnp.int32) - -# # Bottom MLP. -# for dense_dim in self.mlp_bottom_dims: -# bot_mlp_input = nn.Dense( -# dense_dim, -# kernel_init=jnn.initializers.glorot_uniform(), -# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), -# )( -# bot_mlp_input) -# bot_mlp_input = nn.relu(bot_mlp_input) -# if self.use_layer_norm: -# bot_mlp_input = nn.LayerNorm()(bot_mlp_input) -# bot_mlp_output = bot_mlp_input -# batch_size = bot_mlp_output.shape[0] -# feature_stack = jnp.reshape(bot_mlp_output, -# [batch_size, -1, self.embed_dim]) - -# # Embedding table look-up. 
-# idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - -# def scaled_init(key, shape, dtype=jnp.float_): -# return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / -# jnp.sqrt(self.vocab_size)) - -# embedding_table = self.param('embedding_table', -# scaled_init, [self.vocab_size, self.embed_dim]) - -# idx_lookup = jnp.reshape(idx_lookup, [-1]) -# embed_features = embedding_table[idx_lookup] -# embed_features = jnp.reshape(embed_features, -# [batch_size, -1, self.embed_dim]) -# if self.use_layer_norm: -# embed_features = nn.LayerNorm()(embed_features) -# feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) -# dot_interact_output = dot_interact(concat_features=feature_stack) -# top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], -# axis=-1) -# mlp_input_dim = top_mlp_input.shape[1] -# mlp_top_dims = self.mlp_top_dims -# num_layers_top = len(mlp_top_dims) -# for layer_idx, fan_out in enumerate(mlp_top_dims): -# fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] -# top_mlp_input = nn.Dense( -# fan_out, -# kernel_init=jnn.initializers.normal( -# stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), -# bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( -# top_mlp_input) -# if layer_idx < (num_layers_top - 1): -# top_mlp_input = nn.relu(top_mlp_input) -# if self.use_layer_norm: -# top_mlp_input = nn.LayerNorm()(top_mlp_input) -# if (self.dropout_rate is not None and self.dropout_rate > 0.0 and -# layer_idx == num_layers_top - 2): -# top_mlp_input = nn.Dropout( -# rate=self.dropout_rate, deterministic=not train)( -# top_mlp_input) -# logits = top_mlp_input -# return logits +def dot_interact(concat_features): + """Performs feature interaction operation between dense or sparse features. + Input tensors represent dense or sparse features. + Pre-condition: The tensors have been stacked along dimension 1. + Args: + concat_features: Array of features with shape [B, n_features, feature_dim]. + Returns: + activations: Array representing interacted features. + """ + batch_size = concat_features.shape[0] + + # Interact features, select upper or lower-triangular portion, and reshape. + xactions = jnp.matmul(concat_features, + jnp.transpose(concat_features, [0, 2, 1])) + feature_dim = xactions.shape[-1] + + indices = jnp.array(jnp.triu_indices(feature_dim)) + num_elems = indices.shape[1] + indices = jnp.tile(indices, [1, batch_size]) + indices0 = jnp.reshape( + jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]), + [1, -1]) + indices = tuple(jnp.concatenate((indices0, indices), 0)) + activations = xactions[indices] + activations = jnp.reshape(activations, [batch_size, -1]) + return activations + + +class DlrmSmall(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + vocab_size: int = 32 * 128 * 1024 # 4_194_304. 
+ num_dense_features: int = 13 + mlp_bottom_dims: Sequence[int] = (512, 256, 128) + mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) + embed_dim: int = 128 + dropout_rate: float = 0.0 + use_layer_norm: bool = False + + @nn.compact + def __call__(self, x, train): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) + cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + + # Bottom MLP. + for dense_dim in self.mlp_bottom_dims: + bot_mlp_input = nn.Dense( + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)), + )( + bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) + if self.use_layer_norm: + bot_mlp_input = nn.LayerNorm()(bot_mlp_input) + bot_mlp_output = bot_mlp_input + batch_size = bot_mlp_output.shape[0] + feature_stack = jnp.reshape(bot_mlp_output, + [batch_size, -1, self.embed_dim]) + + # Embedding table look-up. + idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + + def scaled_init(key, shape, dtype=jnp.float_): + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / + jnp.sqrt(self.vocab_size)) + + embedding_table = self.param('embedding_table', + scaled_init, [self.vocab_size, self.embed_dim]) + + idx_lookup = jnp.reshape(idx_lookup, [-1]) + embed_features = embedding_table[idx_lookup] + embed_features = jnp.reshape(embed_features, + [batch_size, -1, self.embed_dim]) + if self.use_layer_norm: + embed_features = nn.LayerNorm()(embed_features) + feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) + dot_interact_output = dot_interact(concat_features=feature_stack) + top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], + axis=-1) + mlp_input_dim = top_mlp_input.shape[1] + mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) + for layer_idx, fan_out in enumerate(mlp_top_dims): + fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1] + top_mlp_input = nn.Dense( + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))( + top_mlp_input) + if layer_idx < (num_layers_top - 1): + top_mlp_input = nn.relu(top_mlp_input) + if self.use_layer_norm: + top_mlp_input = nn.LayerNorm()(top_mlp_input) + if (self.dropout_rate is not None and self.dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + top_mlp_input = nn.Dropout( + rate=self.dropout_rate, deterministic=not train)( + top_mlp_input) + logits = top_mlp_input + return logits diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 0ae5a6b29..5007e94f6 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -188,7 +188,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) - embed_dim: int = 128 @property def use_resnet(self) -> bool: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 27ad62ff3..d70ec3ef0 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -61,7 +61,7 @@ def __init__(self, num_sparse_features=26, mlp_bottom_dims=(256, 256, 256), mlp_top_dims=(256, 256, 256, 256, 1), - embed_dim=256, + embed_dim=128, dropout_rate=0.0, use_layer_norm=False): super().__init__() diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index c3c2ed935..eab47738e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -263,7 +263,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) - embed_dim: int = 128 @property def use_resnet(self) -> bool: From 49b45f6a6f93020b059d7f1c25e75670b916ddc2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 18:56:50 +0000 Subject: [PATCH 138/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index d4cea60ca..0eda62d04 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -52,6 +52,9 @@ def __call__(self, x, train): base_init_fn = jnn.initializers.uniform(scale=1.0) # Embedding table init and lookup for a single unified table. + print("cat features") + print(cat_features) + print(jnp.shape(jnp.reshape(cat_features, [-1]))) idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size print("idx shape") print(jnp.shape(idx_lookup)) From 3ed36331dcbc83e8345982850dec89c047b4e032 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 19:38:10 +0000 Subject: [PATCH 139/169] criteo variants --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 6 ------ .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 12 ++++-------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 0eda62d04..443fe134d 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -52,12 +52,7 @@ def __call__(self, x, train): base_init_fn = jnn.initializers.uniform(scale=1.0) # Embedding table init and lookup for a single unified table. 
- print("cat features") - print(cat_features) - print(jnp.shape(jnp.reshape(cat_features, [-1]))) idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - print("idx shape") - print(jnp.shape(idx_lookup)) def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) @@ -67,7 +62,6 @@ def scaled_init(key, shape, dtype=jnp.float_): [self.vocab_size, self.embed_dim]) embed_features = embedding_table[idx_lookup] - print(jnp.shape(embed_features)) batch_size = bot_mlp_input.shape[0] embed_features = jnp.reshape( embed_features, (batch_size, 26 * self.embed_dim)) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index d70ec3ef0..3fd431cd6 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -110,10 +110,8 @@ def __init__(self, 0., math.sqrt(1. / module.out_features)) - self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,) - - # TODO: Write down the formula here instead of the constant. - fan_in = 634 + # Number of sparse features = 26 + fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] num_layers_top = len(self.mlp_top_dims) mlp_top_blocks = [] for layer_idx, fan_out in enumerate(self.mlp_top_dims): @@ -159,12 +157,10 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, -1, self.embed_dim]) - # Dot product interactions. - concatenated_dense = self.dot_interact( - dense_features=embedded_dense, sparse_features=embedded_sparse) + top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. - logits = self.top_mlp(concatenated_dense) + logits = self.top_mlp(top_mlp_input) return logits From 834ac9c0f6709f7555976dc108875d882b638d41 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 19:51:53 +0000 Subject: [PATCH 140/169] debugging --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 3fd431cd6..32fdd66d3 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -157,6 +157,8 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, -1, self.embed_dim]) + print(embedded_sparse.shape) + print(embedded_dense.shape) top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. 
From e4c7b3426b6b0e2103385f494ed9a3dbf2a6ce21 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 19:54:04 +0000 Subject: [PATCH 141/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 32fdd66d3..b659c321f 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -156,7 +156,7 @@ def forward(self, x): embedding_table = torch.cat(self.embedding_table_chucks, dim=0) embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, - [batch_size, -1, self.embed_dim]) + [batch_size, 26 * self.embed_dim]) print(embedded_sparse.shape) print(embedded_dense.shape) top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) From ac6baf76ee1130d6835d06a6a2d7307a6a2b7443 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 20:51:40 +0000 Subject: [PATCH 142/169] embedding initialization criteo --- .../criteo1tb/criteo1tb_jax/models.py | 9 +++++-- .../criteo1tb/criteo1tb_jax/workload.py | 24 ++++++++++++++++++- .../criteo1tb/criteo1tb_pytorch/models.py | 13 +++++++--- .../criteo1tb/criteo1tb_pytorch/workload.py | 24 ++++++++++++++++++- .../workloads/criteo1tb/workload.py | 4 ++++ algorithmic_efficiency/workloads/workloads.py | 4 ++++ 6 files changed, 71 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 443fe134d..791096938 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -25,6 +25,7 @@ class DLRMResNet(nn.Module): embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. + embedding_init_multiplier: float = None # Unused @nn.compact def __call__(self, x, train): @@ -150,6 +151,7 @@ class DlrmSmall(nn.Module): embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False + embedding_init_multiplier = None @nn.compact def __call__(self, x, train): @@ -175,9 +177,12 @@ def __call__(self, x, train): # Embedding table look-up. 
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + if self.embedding_init_multiplier is None: + embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size) + else: + embedding_init_multiplier = self.embedding_init_multiplier def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / - jnp.sqrt(self.vocab_size)) + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype), embedding_init_multiplier) embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 5007e94f6..d27adb7c1 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -91,7 +91,8 @@ def init_model_fn( mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm) + use_layer_norm=self.use_layer_norm, + use_embedding_init_multiplier=self.embedding_init_multiplier) params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 @@ -201,3 +202,24 @@ def validation_target_value(self) -> float: @property def test_target_value(self) -> float: return 0.126468 + + +class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.124286 + + @property + def test_target_value(self) -> float: + # Todo + return 0.126725 + + @property + def embedding_init_multiplier(self) -> float: + return 1.0 \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index b659c321f..1a1e04948 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -63,7 +63,8 @@ def __init__(self, mlp_top_dims=(256, 256, 256, 256, 1), embed_dim=128, dropout_rate=0.0, - use_layer_norm=False): + use_layer_norm=False, + embedding_init_multiplier=None): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -185,7 +186,8 @@ def __init__(self, mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, dropout_rate=0.0, - use_layer_norm=False): + use_layer_norm=False, + embedding_init_multiplier=None): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -200,7 +202,12 @@ def __init__(self, num_chucks = 4 assert vocab_size % num_chucks == 0 self.embedding_table_chucks = [] - scale = 1.0 / torch.sqrt(self.vocab_size) + + if self.embedding_init_multiplier is None: + scale = 1.0 / torch.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + for i in range(num_chucks): chunk = nn.Parameter( torch.Tensor(self.vocab_size // num_chucks, self.embed_dim)) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index eab47738e..32f187031 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ 
b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -86,7 +86,8 @@ def init_model_fn( mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, dropout_rate=dropout_rate, - use_layer_norm=self.use_layer_norm) + use_layer_norm=self.use_layer_norm, + embedding_init_mulitplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -276,3 +277,24 @@ def validation_target_value(self) -> float: @property def test_target_value(self) -> float: return 0.126468 + + +class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.124286 + + @property + def test_target_value(self) -> float: + # Todo + return 0.126725 + + @property + def embedding_init_multiplier(self) -> float: + return 1.0 \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 4b2dcbf19..17661e670 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -43,6 +43,10 @@ def use_resnet(self) -> bool: """Whether or not to use residual connections in the model.""" return False + @property + def embedding_init_multiplier(self) -> float: + return None + @property def validation_target_value(self) -> float: return 0.123735 diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 410fdb1f3..8891a6e18 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -24,6 +24,10 @@ 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbeddingInitWorkload' + }, 'criteo1tb_resnet': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' From e138c830b3f034ae293c04e884ef5303dabf8272 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 20:52:00 +0000 Subject: [PATCH 143/169] test embedding init --- .../criteo1tb_embed_init/__init__.py | 0 .../criteo1tb_embed_init/compare.py | 85 +++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 tests/modeldiffs/criteo1tb_embed_init/__init__.py create mode 100644 tests/modeldiffs/criteo1tb_embed_init/compare.py diff --git a/tests/modeldiffs/criteo1tb_embed_init/__init__.py b/tests/modeldiffs/criteo1tb_embed_init/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py new file mode 100644 index 000000000..fb582fbec --- /dev/null +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -0,0 +1,85 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ + Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ + Criteo1TbDlrmSmallEmbedInitWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + s_count = None + print('key') + print(k) + for i in k: + if 'Sequential' in i: + s_count = int(i.split('_')[1]) + continue + if 'Embedding' in i: + return ('embedding_table',) + if 'Linear' in i: + i = i.replace('Linear', 'Dense') + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) + elif 'weight' in i: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = { + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), + } + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) From 132223b1f98ff51d3f0e048a6b869e16d7f10923 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 20:56:05 +0000 Subject: [PATCH 144/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index d27adb7c1..fd252fd83 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -92,7 +92,7 @@ def init_model_fn( embed_dim=self.embed_dim, dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, - use_embedding_init_multiplier=self.embedding_init_multiplier) + embedding_init_multiplier=self.embedding_init_multiplier) params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 From 071ddf3dbad369538a9468fd616897372ce0a096 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:00:39 +0000 Subject: [PATCH 145/169] add embedding init multiplier --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 791096938..d4d0c6474 100644 --- 
a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -151,7 +151,7 @@ class DlrmSmall(nn.Module): embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False - embedding_init_multiplier = None + embedding_init_multiplier: float = None @nn.compact def __call__(self, x, train): From 134d0bb9f840b745b3732a9f77a0148f6a1fffc3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:05:30 +0000 Subject: [PATCH 146/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index d4d0c6474..33eba1930 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -181,6 +181,7 @@ def __call__(self, x, train): embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size) else: embedding_init_multiplier = self.embedding_init_multiplier + embedding_init_multiplier = 1. def scaled_init(key, shape, dtype=jnp.float_): return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype), embedding_init_multiplier) From 955b70d697f03c7acd665d0d2b68987103c5b06f Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:08:53 +0000 Subject: [PATCH 147/169] test --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 33eba1930..c9495095e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -177,10 +177,10 @@ def __call__(self, x, train): # Embedding table look-up. idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - if self.embedding_init_multiplier is None: - embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size) - else: - embedding_init_multiplier = self.embedding_init_multiplier + # if self.embedding_init_multiplier is None: + # embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size) + # else: + # embedding_init_multiplier = self.embedding_init_multiplier embedding_init_multiplier = 1. 
def scaled_init(key, shape, dtype=jnp.float_): return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype), embedding_init_multiplier) From dadd4d640df41638d4c122589b855d8d18d5792c Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:21:13 +0000 Subject: [PATCH 148/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 32f187031..1daf71df4 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -87,7 +87,7 @@ def init_model_fn( embed_dim=self.embed_dim, dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, - embedding_init_mulitplier=self.embedding_init_multiplier) + embedding_init_multitplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) From 8e832a341f866f0330db8851c9f41b1f84b865ac Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:22:27 +0000 Subject: [PATCH 149/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 1daf71df4..6eba695c0 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -87,7 +87,7 @@ def init_model_fn( embed_dim=self.embed_dim, dropout_rate=dropout_rate, use_layer_norm=self.use_layer_norm, - embedding_init_multitplier=self.embedding_init_multiplier) + embedding_init_multiplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) From 07daeee197075c027919920adab364302a7aef43 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:26:10 +0000 Subject: [PATCH 150/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index c9495095e..f47dc0fad 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -181,9 +181,9 @@ def __call__(self, x, train): # embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size) # else: # embedding_init_multiplier = self.embedding_init_multiplier - embedding_init_multiplier = 1. + # embedding_init_multiplier = 1. 
def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype), embedding_init_multiplier) + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype), 1 / jnp.sqrt(self.vocab_size)) embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) From 263108f3bc3a71543d7daa6e4e5ee33fc3a494df Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:29:00 +0000 Subject: [PATCH 151/169] debug --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 3 ++- tests/modeldiffs/criteo1tb/compare.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index f47dc0fad..5a9855408 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -183,7 +183,8 @@ def __call__(self, x, train): # embedding_init_multiplier = self.embedding_init_multiplier # embedding_init_multiplier = 1. def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype), 1 / jnp.sqrt(self.vocab_size)) + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / + jnp.sqrt(self.vocab_size)) embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index cb6806596..9a95f3656 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -18,8 +18,6 @@ def key_transform(k): new_key = [] s_count = None - print('key') - print(k) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) @@ -32,6 +30,7 @@ def key_transform(k): i = name + '_' + str(s_count * 3 + int(count)) elif 'weight' in i: i = i.replace('weight', 'kernel') + new_key.append(i) return tuple(new_key) From a7dd1615d83f00034d70e45a1ddc39945657cfca Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:30:58 +0000 Subject: [PATCH 152/169] fix --- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 1a1e04948..67e365f63 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -195,6 +195,7 @@ def __init__(self, self.mlp_bottom_dims = mlp_bottom_dims self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim + self.embedding_init_multiplier = embedding_init_multiplier # Ideally, we should use the pooled embedding implementation from # `TorchRec`. 
However, in order to have identical implementation From 40a4cde7aab411a72f922a5b625f2da1cc317f19 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:34:45 +0000 Subject: [PATCH 153/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 5a9855408..b52e50ba2 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -178,13 +178,13 @@ def __call__(self, x, train): idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size # if self.embedding_init_multiplier is None: - # embedding_init_multiplier = 1 / jnp.sqrt(self.vocab_size) + # scale = 1 / jnp.sqrt(self.vocab_size) # else: - # embedding_init_multiplier = self.embedding_init_multiplier + # scale = self.embedding_init_multiplier # embedding_init_multiplier = 1. def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / - jnp.sqrt(self.vocab_size)) + return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * + 1/jnp.sqrt(self.vocab_size)) embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) From 1740325481b0f09f2c6b7c09ed717ba564bb2a74 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:36:38 +0000 Subject: [PATCH 154/169] debugging --- .../workloads/criteo1tb/criteo1tb_jax/models.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index b52e50ba2..d7d6ef60f 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -177,14 +177,13 @@ def __call__(self, x, train): # Embedding table look-up. idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size - # if self.embedding_init_multiplier is None: - # scale = 1 / jnp.sqrt(self.vocab_size) - # else: - # scale = self.embedding_init_multiplier - # embedding_init_multiplier = 1. 
+ if self.embedding_init_multiplier is None: + scale = 1 / jnp.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier def scaled_init(key, shape, dtype=jnp.float_): return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * - 1/jnp.sqrt(self.vocab_size)) + scale) embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) From b05f5109b82e0ee772c576c0b1bc17fe2c138c75 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:40:57 +0000 Subject: [PATCH 155/169] fix --- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 ----- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 7 +------ 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index fd252fd83..818015e2c 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -206,11 +206,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): - @property - def use_layer_norm(self) -> bool: - """Whether or not to use LayerNorm in the model.""" - return True - @property def validation_target_value(self) -> float: return 0.124286 diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 6eba695c0..5633082c5 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -281,11 +281,6 @@ def test_target_value(self) -> float: class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): - @property - def use_layer_norm(self) -> bool: - """Whether or not to use LayerNorm in the model.""" - return True - @property def validation_target_value(self) -> float: return 0.124286 @@ -294,7 +289,7 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: # Todo return 0.126725 - + @property def embedding_init_multiplier(self) -> float: return 1.0 \ No newline at end of file From 3cf57a77c05c3b0886fb9b2f68ea929e66ef43ff Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:50:40 +0000 Subject: [PATCH 156/169] add tests --- .github/workflows/regression_tests_variants.yml | 11 ++++++++++- docker/scripts/startup.sh | 2 +- tests/modeldiffs/criteo1tb_embed_init/compare.py | 2 -- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index d581c80fc..ec7a07e76 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -72,5 +72,14 @@ jobs: run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - + 
criteo_resnet_pytorch: + runs-on: self-hosted + needs: build_and_push_pytorch_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized workload + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_embed_init -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + \ No newline at end of file diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index c6ac2d701..06bd2c0da 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -114,7 +114,7 @@ VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt" "mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_vit" "fastmri" "ogbg" \ "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \ - "criteo1tb_resnet" "criteo1tb_layernorm") + "criteo1tb_resnet" "criteo1tb_layernorm" "criteo1tb_embed_init") # Set data and experiment paths diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index fb582fbec..719484037 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -18,8 +18,6 @@ def key_transform(k): new_key = [] s_count = None - print('key') - print(k) for i in k: if 'Sequential' in i: s_count = int(i.split('_')[1]) From a5568edd1aa131925bf81095b5e461dabb0169fd Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:56:06 +0000 Subject: [PATCH 157/169] clean up; --- algorithmic_efficiency/logger_utils.py | 4 +--- .../workloads/criteo1tb/criteo1tb_jax/models.py | 4 +--- .../workloads/criteo1tb/criteo1tb_jax/workload.py | 4 ---- .../workloads/criteo1tb/criteo1tb_pytorch/models.py | 6 ------ tests/modeldiffs/torch2jax_utils.py | 9 ++++----- 5 files changed, 6 insertions(+), 21 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 9a881523a..b7bde226a 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -2,8 +2,6 @@ import collections import json -import flax.linen as nn -import jax import logging import os.path import platform @@ -343,4 +341,4 @@ def set_up_loggers(train_dir: str, events_dir=train_dir, configs=configs, hyperparameters=hyperparameters) - return metrics_logger \ No newline at end of file + return metrics_logger diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index d7d6ef60f..6baf9c6fc 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -22,7 +22,7 @@ class DLRMResNet(nn.Module): num_dense_features: int = 13 mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) - embed_dim: int = 128 + embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. 
embedding_init_multiplier: float = None # Unused @@ -67,8 +67,6 @@ def scaled_init(key, shape, dtype=jnp.float_): embed_features = jnp.reshape( embed_features, (batch_size, 26 * self.embed_dim)) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) - print("top mlp input shape") - print(jnp.shape(top_mlp_input)) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 818015e2c..44c3b06dd 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -100,14 +100,10 @@ def init_model_fn( num_dense_features = 13 input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) - print('Input Shape') - print(input_shape) init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape, jnp.float32)) - fake_inputs = jnp.ones(input_shape, jnp.float32) - utils.print_jax_model_summary(self._model, fake_inputs) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index 67e365f63..4110b30fc 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -30,10 +30,6 @@ def __init__(self, num_sparse_features): num_sparse_features + 1) def forward(self, dense_features, sparse_features): - print("Dense features shape") - print(dense_features.shape) - print(sparse_features.shape) - print(dense_features.unsqueeze(1).shape) combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), dim=1) interactions = torch.bmm(combined_values, @@ -158,8 +154,6 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, 26 * self.embed_dim]) - print(embedded_sparse.shape) - print(embedded_dense.shape) top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) # Final MLP. 
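Between the two file diffs in this cleanup patch, a minimal sketch (not from the patch) of where the embedding-initialization scaling lands at this point in the series, assuming vocab_size = 4_194_304 and the workload defaults:

    import math

    vocab_size = 4_194_304
    # Default DLRM behaviour: uniform(0, 1) table entries scaled down by 1/sqrt(vocab).
    default_scale = 1.0 / math.sqrt(vocab_size)  # 1 / 2048 ~= 4.9e-4
    # The embed_init variant (Criteo1TbDlrmSmallEmbedInitWorkload) overrides
    # embedding_init_multiplier to 1.0, leaving the uniform(0, 1) entries unscaled.
    variant_scale = 1.0

    # On the JAX side the model code multiplies the initializer output by the scale:
    #   jnn.initializers.uniform(scale=1.0)(key, shape) * scale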
diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 9600cd204..d9264b400 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -77,11 +77,10 @@ def key_transform(self, k_transform_fn): } def value_transform(self, v_transform_fn): - for k in self.pytorch_sd: - self.pytorch_sd = { - k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) - for k in self.pytorch_sd - } + self.pytorch_sd = { + k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k]) + for k in self.pytorch_sd + } def sd_transform(self, sd_transform_fn): self.pytorch_sd = sd_transform_fn(self.pytorch_sd) From 76d1749e561f07f5e396ee2855d6792ffc9ab7e5 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Nov 2023 21:59:36 +0000 Subject: [PATCH 158/169] formatting --- .../criteo1tb/criteo1tb_jax/models.py | 35 +++++++++---------- .../criteo1tb/criteo1tb_jax/workload.py | 4 +-- .../criteo1tb/criteo1tb_pytorch/workload.py | 4 +-- .../fastmri/fastmri_pytorch/workload.py | 4 +-- .../imagenet_pytorch/workload.py | 4 +-- .../librispeech_jax/spectrum_augmenter.py | 4 +-- .../librispeech_pytorch/workload.py | 9 +++-- .../workloads/mnist/workload.py | 4 +-- algorithmic_efficiency/workloads/utils.py | 15 ++++---- .../workloads/wmt/wmt_pytorch/models.py | 4 +-- 10 files changed, 41 insertions(+), 46 deletions(-) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index 6baf9c6fc..025dde5f1 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -22,10 +22,10 @@ class DLRMResNet(nn.Module): num_dense_features: int = 13 mlp_bottom_dims: Sequence[int] = (256, 256, 256) mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) - embed_dim: int = 128 + embed_dim: int = 128 dropout_rate: float = 0.0 use_layer_norm: bool = False # Unused. - embedding_init_multiplier: float = None # Unused + embedding_init_multiplier: float = None # Unused @nn.compact def __call__(self, x, train): @@ -38,9 +38,9 @@ def __call__(self, x, train): bot_mlp_input = nn.Dense( mlp_bottom_dims[0], kernel_init=jnn.initializers.glorot_uniform(), - bias_init=jnn.initializers.normal( - stddev=1.0 / mlp_bottom_dims[0]**0.5), - )(bot_mlp_input) + bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5), + )( + bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) for dense_dim in mlp_bottom_dims[1:]: @@ -48,24 +48,24 @@ def __call__(self, x, train): dense_dim, kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), - )(bot_mlp_input) + )( + bot_mlp_input) bot_mlp_input += nn.relu(x) base_init_fn = jnn.initializers.uniform(scale=1.0) # Embedding table init and lookup for a single unified table. 
     idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
+
     def scaled_init(key, shape, dtype=jnp.float_):
       return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)
 
-    embedding_table = self.param(
-        'embedding_table',
-        scaled_init,
-        [self.vocab_size, self.embed_dim])
+    embedding_table = self.param('embedding_table',
+                                 scaled_init, [self.vocab_size, self.embed_dim])
 
     embed_features = embedding_table[idx_lookup]
     batch_size = bot_mlp_input.shape[0]
-    embed_features = jnp.reshape(
-        embed_features, (batch_size, 26 * self.embed_dim))
+    embed_features = jnp.reshape(embed_features,
+                                 (batch_size, 26 * self.embed_dim))
     top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
     mlp_input_dim = top_mlp_input.shape[1]
     mlp_top_dims = self.mlp_top_dims
@@ -89,8 +89,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
                 top_mlp_input)
       x = nn.relu(x)
       if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2:
-        x = nn.Dropout(
-            rate=self.dropout_rate, deterministic=not train)(x)
+        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
       top_mlp_input += x
     # In the DLRM model the last layer width is always 1. We can hardcode that
     # below.
@@ -98,8 +97,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
         1,
         kernel_init=jnn.initializers.normal(
            stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
-        bias_init=jnn.initializers.normal(
-            stddev=jnp.sqrt(1.0)))(top_mlp_input)
+        bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(
+            top_mlp_input)
     return logits
 
 
@@ -179,9 +178,9 @@ def __call__(self, x, train):
       scale = 1 / jnp.sqrt(self.vocab_size)
     else:
       scale = self.embedding_init_multiplier
+
     def scaled_init(key, shape, dtype=jnp.float_):
-      return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) *
-              scale)
+      return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale)
 
     embedding_table = self.param('embedding_table',
                                  scaled_init, [self.vocab_size, self.embed_dim])
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
index 44c3b06dd..b13c3498a 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -210,7 +210,7 @@ def validation_target_value(self) -> float:
   def test_target_value(self) -> float:
     # Todo
     return 0.126725
-  
+
   @property
   def embedding_init_multiplier(self) -> float:
-    return 1.0
\ No newline at end of file
+    return 1.0
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 5633082c5..85bb602d1 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -289,7 +289,7 @@ def validation_target_value(self) -> float:
   def test_target_value(self) -> float:
     # Todo
     return 0.126725
-  
+
   @property
   def embedding_init_multiplier(self) -> float:
-    return 1.0
\ No newline at end of file
+    return 1.0
diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
index daaea9e10..c3252feb8 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -247,9 +247,7 @@ def _eval_model_on_split(self,
     for _ in range(num_batches):
       batch = next(self._eval_iters[split])
       batch_metrics = self._eval_model(params, batch, model_rng)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index c0fcaaef3..cc9d2febc 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -282,9 +282,7 @@ def _eval_model_on_split(self,
                                update_batch_norm=False)
       weights = batch.get('weights')
       batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
index 2a6f73d4d..c16740629 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
@@ -81,8 +81,8 @@ def _get_mask(self,
           jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
           [batch_size, 1])
       multiplicity_tensor = masks_per_frame * choose_range
-      multiplicity_weights = (multiplicity_weights <
-                              multiplicity_tensor).astype(jnp.int32)
+      multiplicity_weights = (multiplicity_weights
+                              < multiplicity_tensor).astype(jnp.int32)
       pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
     else:
       pre_mask = jnp.einsum('bmt->bt', pre_mask)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index c4f4a1247..d2774d3b9 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -227,8 +227,9 @@ def greedy_decode(
     idxs = torch.arange(
         fin_result.numel(), device=result.device).view(*fin_result.shape)
     mask = torch.arange(
-        fin_result.shape[1], device=result.device).view(
-            1, -1) < result.count_nonzero(dim=1).view(-1, 1)
+        fin_result.shape[1],
+        device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
+            -1, 1)
     fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
     padding = fin_result == 0
     return fin_result, padding
@@ -296,9 +297,7 @@ def _eval_model_on_split(self,
           'word_errors': word_errors,
          'num_words': num_words,
       }
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index dcc195170..959228755 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -214,8 +214,6 @@ def _eval_model_on_split(self,
                                                 batch,
                                                 model_state,
                                                 per_device_model_rngs)
-      total_metrics = {
-          k: v + batch_metrics[k] for k, v in total_metrics.items()
-      }
+      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
 
     return self._normalize_eval_metrics(num_examples, total_metrics)
diff --git a/algorithmic_efficiency/workloads/utils.py b/algorithmic_efficiency/workloads/utils.py
index 27dac8dee..7719f91fb 100644
--- a/algorithmic_efficiency/workloads/utils.py
+++ b/algorithmic_efficiency/workloads/utils.py
@@ -1,11 +1,14 @@
 import flax.linen as nn
-import jax
+import jax
 
 
 def print_jax_model_summary(model, fake_inputs):
   """Prints a summary of the jax module."""
-  tabulate_fn = nn.tabulate(model, jax.random.PRNGKey(0),
-                            console_kwargs={'force_terminal': False,
-                                            'force_jupyter': False,
-                                            'width': 240},)
-  print(tabulate_fn(fake_inputs, train=False))
\ No newline at end of file
+  tabulate_fn = nn.tabulate(
+      model,
+      jax.random.PRNGKey(0),
+      console_kwargs={
+          'force_terminal': False, 'force_jupyter': False, 'width': 240
+      },
+  )
+  print(tabulate_fn(fake_inputs, train=False))
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
index b787785a1..dc8ebea90 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -912,8 +912,8 @@ def forward(self,
       # not the remaining zero elements.
       if attn_mask is not None:
         raise ValueError('Attention mask has to be None for decode == True.')
-      attn_mask = (torch.arange(max_len, device=k.device) >=
-                   cache_index).reshape(1, max_len)
+      attn_mask = (torch.arange(max_len, device=k.device)
+                   >= cache_index).reshape(1, max_len)
 
       # Update sequence length to account for complete sequence.
       seq_len = k.size(1)

From 284c30c37047ee2dea66acb7f7ca761845671020 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:02:35 +0000
Subject: [PATCH 159/169] reformat

---
 .../workloads/fastmri/fastmri_pytorch/workload.py  | 4 +++-
 .../imagenet_resnet/imagenet_pytorch/workload.py   | 4 +++-
 .../librispeech_jax/spectrum_augmenter.py          | 4 ++--
 .../librispeech_pytorch/workload.py                | 9 +++++----
 algorithmic_efficiency/workloads/mnist/workload.py | 4 +++-
 .../workloads/wmt/wmt_pytorch/models.py            | 4 ++--
 6 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
index c3252feb8..daaea9e10 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py
@@ -247,7 +247,9 @@ def _eval_model_on_split(self,
     for _ in range(num_batches):
       batch = next(self._eval_iters[split])
       batch_metrics = self._eval_model(params, batch, model_rng)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index cc9d2febc..c0fcaaef3 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -282,7 +282,9 @@ def _eval_model_on_split(self,
                                update_batch_norm=False)
       weights = batch.get('weights')
       batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
index c16740629..2a6f73d4d 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
@@ -81,8 +81,8 @@ def _get_mask(self,
           jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
          [batch_size, 1])
       multiplicity_tensor = masks_per_frame * choose_range
-      multiplicity_weights = (multiplicity_weights
-                              < multiplicity_tensor).astype(jnp.int32)
+      multiplicity_weights = (multiplicity_weights <
+                              multiplicity_tensor).astype(jnp.int32)
       pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
     else:
       pre_mask = jnp.einsum('bmt->bt', pre_mask)
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index d2774d3b9..c4f4a1247 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -227,9 +227,8 @@ def greedy_decode(
     idxs = torch.arange(
         fin_result.numel(), device=result.device).view(*fin_result.shape)
     mask = torch.arange(
-        fin_result.shape[1],
-        device=result.device).view(1, -1) < result.count_nonzero(dim=1).view(
-            -1, 1)
+        fin_result.shape[1], device=result.device).view(
+            1, -1) < result.count_nonzero(dim=1).view(-1, 1)
     fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
     padding = fin_result == 0
     return fin_result, padding
@@ -297,7 +296,9 @@ def _eval_model_on_split(self,
           'word_errors': word_errors,
           'num_words': num_words,
       }
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
     if USE_PYTORCH_DDP:
       for metric in total_metrics.values():
         dist.all_reduce(metric)
diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py
index 959228755..dcc195170 100644
--- a/algorithmic_efficiency/workloads/mnist/workload.py
+++ b/algorithmic_efficiency/workloads/mnist/workload.py
@@ -214,6 +214,8 @@ def _eval_model_on_split(self,
                                                 batch,
                                                 model_state,
                                                 per_device_model_rngs)
-      total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()}
+      total_metrics = {
+          k: v + batch_metrics[k] for k, v in total_metrics.items()
+      }
 
     return self._normalize_eval_metrics(num_examples, total_metrics)
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
index dc8ebea90..b787785a1 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -912,8 +912,8 @@ def forward(self,
       # not the remaining zero elements.
       if attn_mask is not None:
         raise ValueError('Attention mask has to be None for decode == True.')
-      attn_mask = (torch.arange(max_len, device=k.device)
-                   >= cache_index).reshape(1, max_len)
+      attn_mask = (torch.arange(max_len, device=k.device) >=
+                   cache_index).reshape(1, max_len)
 
       # Update sequence length to account for complete sequence.
       seq_len = k.size(1)

From 547e2254a5c372a1d7ccd048f0728f79ae5e2639 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:03:54 +0000
Subject: [PATCH 160/169] sorting imports

---
 .../workloads/criteo1tb/criteo1tb_jax/workload.py | 5 +++--
 tests/modeldiffs/criteo1tb_resnet/compare.py      | 2 +-
 tests/modeldiffs/diff.py                          | 1 +
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
index b13c3498a..ad473b04c 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -2,19 +2,20 @@
 
 import functools
 from typing import Dict, Optional, Tuple
+
 from absl import logging
 from flax import jax_utils
+import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import numpy as np
-import flax.linen as nn
 
 from algorithmic_efficiency import param_utils
 from algorithmic_efficiency import spec
+from algorithmic_efficiency.workloads import utils
 from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models
 from algorithmic_efficiency.workloads.criteo1tb.workload import \
     BaseCriteo1TbDlrmSmallWorkload
-from algorithmic_efficiency.workloads import utils
 
 
 class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py
index e44c199c4..114dadb0e 100644
--- a/tests/modeldiffs/criteo1tb_resnet/compare.py
+++ b/tests/modeldiffs/criteo1tb_resnet/compare.py
@@ -4,9 +4,9 @@
 os.environ['CUDA_VISIBLE_DEVICES'] = ''
 
 import jax
+import jax.numpy as jnp
 import numpy as np
 import torch
-import jax.numpy as jnp
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \
diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py
index d56115258..81eff8301 100644
--- a/tests/modeldiffs/diff.py
+++ b/tests/modeldiffs/diff.py
@@ -1,4 +1,5 @@
 import logging
+
 from flax import jax_utils
 import jax
 import numpy as np

From a37d3fae0d4e49568b6e8eeae517342d65c5c15c Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:07:18 +0000
Subject: [PATCH 161/169] fix

---
 tests/modeldiffs/criteo1tb_layernorm/compare.py | 2 +-
 tests/modeldiffs/criteo1tb_resnet/compare.py    | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py
index 1f16e8b97..3fc2a750a 100644
--- a/tests/modeldiffs/criteo1tb_layernorm/compare.py
+++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py
@@ -35,7 +35,7 @@ def key_transform(k):
       # There is a layernorm on embedding between bottom and top MLP
       if s_count is not None:
         i = name + '_' + str(s_count * 4 + int(count))
-      else: 
+      else:
         i = name + '_' + str(3)
     elif 'weight' in i:
       if layer_norm:
diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py
index 114dadb0e..0b7004568 100644
--- a/tests/modeldiffs/criteo1tb_resnet/compare.py
+++ b/tests/modeldiffs/criteo1tb_resnet/compare.py
@@ -31,7 +31,7 @@ def key_transform(k):
       mlp_block_count = int(i.split('_')[1])
       continue
     if 'DenseBlock' in i:
-      # off set resnet block count by 1 
+      # off set resnet block count by 1
       # since first mlp layer has no resnet connection
       resnet_block_count = int(i.split('_')[1])
       continue
@@ -42,7 +42,7 @@ def key_transform(k):
      i = name + '_' + str(mlp_count * 3 + block_count)
     elif 'weight' in i:
       i = i.replace('weight', 'kernel')
-    new_key.append(i) 
+    new_key.append(i)
  return tuple(new_key)

From 1af2a8a595645c48984c5cd85dc0e96c1ec4fe72 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:11:39 +0000
Subject: [PATCH 162/169] remove unused imports

---
 .../workloads/criteo1tb/criteo1tb_jax/workload.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
index ad473b04c..1d4adf516 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -3,9 +3,7 @@
 import functools
 from typing import Dict, Optional, Tuple
 
-from absl import logging
 from flax import jax_utils
-import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import numpy as np

From 72ede87096b255ee946bf3799beeeb9037396242 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:20:36 +0000
Subject: [PATCH 163/169] pylint

---
 .../workloads/criteo1tb/criteo1tb_jax/models.py   | 2 +-
 .../workloads/criteo1tb/criteo1tb_jax/workload.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
index 025dde5f1..dc79332a3 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -180,7 +180,7 @@ def __call__(self, x, train):
       scale = self.embedding_init_multiplier
 
     def scaled_init(key, shape, dtype=jnp.float_):
-      return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale)
+      return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale
 
     embedding_table = self.param('embedding_table',
                                  scaled_init, [self.vocab_size, self.embed_dim])
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
index 1d4adf516..795a6b30e 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -10,7 +10,6 @@
 from algorithmic_efficiency import param_utils
 from algorithmic_efficiency import spec
-from algorithmic_efficiency.workloads import utils
 from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models
 from algorithmic_efficiency.workloads.criteo1tb.workload import \
     BaseCriteo1TbDlrmSmallWorkload
 

From f6025a22a104195b51f5f9e69876c4f511b098c2 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:27:09 +0000
Subject: [PATCH 164/169] add exception

---
 submission_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/submission_runner.py b/submission_runner.py
index 37d89d136..c8f215e52 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -489,7 +489,7 @@ def score_submission_on_workload(workload: spec.Workload,
   data_selection = submission_module.data_selection
   try:
     global_batch_size = submission_module.get_batch_size(workload_name)
-  except:
+  except Exception:
     base_workload_name = workloads.get_base_workload_name(workload_name)
     global_batch_size = submission_module.get_batch_size(base_workload_name)
   # n_gpus has to be set here, because we cannot call the first Jax operation

From f947e23b1f4316473ffbb0b6892f9b5ac0c8eb00 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:44:53 +0000
Subject: [PATCH 165/169] add clarifying docs for submission example

---
 submissions/template/submission.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index 83297a7d9..1b20f1f04 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -4,6 +4,8 @@
 and https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#disallowed-submissions
 for guidelines.
 """
+from algorithmic_efficiency import spec
+from typing import Dict, List, Tuple, Iterator
 
 
 def init_optimizer_state(workload: spec.Workload,
@@ -41,19 +43,13 @@ def update_params(workload: spec.Workload,
 
 
 def get_batch_size(workload_name):
   """
-  Returns batch size for each workload.
-  Valid workload_name values are in
-  ["wmt",
-  "ogbg",
-  "criteo1tb",
-  "fastmri",
-  "imagenet_resnet",
-  "imagenet_vit",
-  "librispeech_deepspeech",
-  "librispeech_conformer"]
+  Gets batch size for workload.
+  Args:
+    workload_name (str): Valid workload_name values are in["wmt", "ogbg", "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", "librispeech_deepspeech", "librispeech_conformer"]
   Returns:
-  batch_size
-
+    int: batch_size
+  Raises:
+    ValueError: If workload_name is not handled.
   """
   pass

From 765c363da12d3569c488e7841cdd1db6066e92cf Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:46:01 +0000
Subject: [PATCH 166/169] formatting

---
 submissions/template/submission.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index 1b20f1f04..320f9dc4c 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -45,9 +45,11 @@ def get_batch_size(workload_name):
   """
   Gets batch size for workload.
   Args:
-    workload_name (str): Valid workload_name values are in["wmt", "ogbg", "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", "librispeech_deepspeech", "librispeech_conformer"]
+    workload_name (str): Valid workload_name values are: "wmt", "ogbg",
+      "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit",
+      "librispeech_deepspeech", "librispeech_conformer".
   Returns:
-    int: batch_size
+    int: batch_size
   Raises:
     ValueError: If workload_name is not handled.
   """

From 9da4996933d8e12f86ac63c6169064916d35c64e Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 22:46:19 +0000
Subject: [PATCH 167/169] make exception specific

---
 submission_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/submission_runner.py b/submission_runner.py
index c8f215e52..dcf228595 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -489,7 +489,7 @@ def score_submission_on_workload(workload: spec.Workload,
   data_selection = submission_module.data_selection
   try:
     global_batch_size = submission_module.get_batch_size(workload_name)
-  except Exception:
+  except ValueError:
     base_workload_name = workloads.get_base_workload_name(workload_name)
     global_batch_size = submission_module.get_batch_size(base_workload_name)
   # n_gpus has to be set here, because we cannot call the first Jax operation

From 96bcea6c3027962d1e740612e9c3c9845a0e161a Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 23:04:21 +0000
Subject: [PATCH 168/169] formatting

---
 submissions/template/submission.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index 320f9dc4c..3dd01cae0 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -4,8 +4,9 @@
 and https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#disallowed-submissions
 for guidelines.
 """
+from typing import Dict, Iterator, List, Tuple
+
 from algorithmic_efficiency import spec
-from typing import Dict, List, Tuple, Iterator
 
 
 def init_optimizer_state(workload: spec.Workload,

From 59384c6b5e107e3ee67e995d0e615b70045e0b4c Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 22 Nov 2023 23:15:15 +0000
Subject: [PATCH 169/169] pylint

---
 tests/modeldiffs/criteo1tb_resnet/compare.py | 2 +-
 tests/modeldiffs/diff.py                     | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py
index 0b7004568..b9dbbc80e 100644
--- a/tests/modeldiffs/criteo1tb_resnet/compare.py
+++ b/tests/modeldiffs/criteo1tb_resnet/compare.py
@@ -37,7 +37,7 @@ def key_transform(k):
       continue
     if 'Linear' in i:
       i = i.replace('Linear', 'Dense')
-      name, count = i.split('_')
+      name, _ = i.split('_')
       block_count = mlp_block_count if mlp_block_count else resnet_block_count
       i = name + '_' + str(mlp_count * 3 + block_count)
     elif 'weight' in i:
diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py
index 81eff8301..f96fa672b 100644
--- a/tests/modeldiffs/diff.py
+++ b/tests/modeldiffs/diff.py
@@ -1,5 +1,3 @@
-import logging
-
 from flax import jax_utils
 import jax
 import numpy as np