diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index 5a93e08b1..15eccba4c 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -44,7 +44,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_jax: runs-on: self-hosted needs: build_and_push_jax_docker_image @@ -53,7 +53,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_layernorm_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -62,7 +62,7 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + docker run -v $HOME/data/:/data/ 
-v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false criteo_resnet_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image @@ -71,31 +71,15 @@ jobs: - name: Run containerized workload run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/algorithmic-efficiency:/algorithmic-efficiency -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false -librispeech_conformer_layernorm_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_conformer_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false -librispeech_conformer_attention_temperature_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_conformer_attention_temperature -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false -librispeech_conformer_gelu_jax: + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + criteo_embed_init_pytorch: runs-on: self-hosted - needs: build_and_push_jax_docker_image + needs: build_and_push_pytorch_docker_image steps: - uses:
actions/checkout@v2 - name: Run containerized workload run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_conformer_gelu -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false \ No newline at end of file + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_embed_init -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false + + diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py index d47f1b484..dc79332a3 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py @@ -7,6 +7,101 @@ import jax.numpy as jnp +class DLRMResNet(nn.Module): + """Define a DLRMResNet model. + + Parameters: + vocab_size: the size of a single unified embedding table. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + num_dense_features: number of dense features as the bottom mlp input. + embed_dim: embedding dimension. + """ + + vocab_size: int = 32 * 128 * 1024 # 4_194_304 + num_dense_features: int = 13 + mlp_bottom_dims: Sequence[int] = (256, 256, 256) + mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1) + embed_dim: int = 128 + dropout_rate: float = 0.0 + use_layer_norm: bool = False # Unused. + embedding_init_multiplier: float = None # Unused + + @nn.compact + def __call__(self, x, train): + bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) + cat_features = jnp.asarray(cat_features, dtype=jnp.int32) + + # bottom mlp + mlp_bottom_dims = self.mlp_bottom_dims + + bot_mlp_input = nn.Dense( + mlp_bottom_dims[0], + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5), + )( + bot_mlp_input) + bot_mlp_input = nn.relu(bot_mlp_input) + + for dense_dim in mlp_bottom_dims[1:]: + x = nn.Dense( + dense_dim, + kernel_init=jnn.initializers.glorot_uniform(), + bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5), + )( + bot_mlp_input) + bot_mlp_input += nn.relu(x) + + base_init_fn = jnn.initializers.uniform(scale=1.0) + # Embedding table init and lookup for a single unified table. 
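A side note on the residual wiring in the JAX `DLRMResNet` bottom MLP above: every hidden layer after the first adds its ReLU output back onto the running activation instead of replacing it. A minimal Flax sketch of just that pattern, with toy dimensions and a hypothetical module name (it omits the glorot/normal initializers used in the diff):

```python
from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


class ResidualBottomMLPSketch(nn.Module):
  """Toy illustration of the DLRMResNet bottom-MLP residual pattern."""
  dims: Sequence[int] = (256, 256, 256)

  @nn.compact
  def __call__(self, x):
    # First layer projects the 13 dense features up to the working width.
    h = nn.relu(nn.Dense(self.dims[0])(x))
    # Every later layer is additive: h <- h + relu(Dense(h)).
    for dim in self.dims[1:]:
      h = h + nn.relu(nn.Dense(dim)(h))
    return h


params = ResidualBottomMLPSketch().init(jax.random.PRNGKey(0), jnp.ones((2, 13)))
```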
+ idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + + def scaled_init(key, shape, dtype=jnp.float_): + return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size) + + embedding_table = self.param('embedding_table', + scaled_init, [self.vocab_size, self.embed_dim]) + + embed_features = embedding_table[idx_lookup] + batch_size = bot_mlp_input.shape[0] + embed_features = jnp.reshape(embed_features, + (batch_size, 26 * self.embed_dim)) + top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) + mlp_input_dim = top_mlp_input.shape[1] + mlp_top_dims = self.mlp_top_dims + num_layers_top = len(mlp_top_dims) + top_mlp_input = nn.Dense( + mlp_top_dims[0], + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))( + top_mlp_input) + top_mlp_input = nn.relu(top_mlp_input) + for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: + fan_in = mlp_top_dims[layer_idx - 1] + x = nn.Dense( + fan_out, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (fan_in + fan_out))), + bias_init=jnn.initializers.normal( + stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))( + top_mlp_input) + x = nn.relu(x) + if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: + x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) + top_mlp_input += x + # In the DLRM model the last layer width is always 1. We can hardcode that + # below. + logits = nn.Dense( + 1, + kernel_init=jnn.initializers.normal( + stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))), + bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))( + top_mlp_input) + return logits + + def dot_interact(concat_features): """Performs feature interaction operation between dense or sparse features. Input tensors represent dense or sparse features. @@ -52,6 +147,8 @@ class DlrmSmall(nn.Module): mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1) embed_dim: int = 128 dropout_rate: float = 0.0 + use_layer_norm: bool = False + embedding_init_multiplier: float = None @nn.compact def __call__(self, x, train): @@ -67,6 +164,8 @@ def __call__(self, x, train): )( bot_mlp_input) bot_mlp_input = nn.relu(bot_mlp_input) + if self.use_layer_norm: + bot_mlp_input = nn.LayerNorm()(bot_mlp_input) bot_mlp_output = bot_mlp_input batch_size = bot_mlp_output.shape[0] feature_stack = jnp.reshape(bot_mlp_output, @@ -75,9 +174,13 @@ def __call__(self, x, train): # Embedding table look-up. 
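The embedding lookup that follows is where the `embedding_init_multiplier` variant takes effect: the uniformly initialized table is scaled by `1 / sqrt(vocab_size)` by default, or by the explicit multiplier when one is provided (1.0 in `Criteo1TbDlrmSmallEmbedInitWorkload`). A standalone sketch of that initializer choice, with toy sizes and a hypothetical helper name:

```python
import jax
import jax.numpy as jnp
from jax import nn as jnn


def init_embedding_table(key, vocab_size, embed_dim, init_multiplier=None):
  """Uniform [0, 1) init, rescaled the way the DlrmSmall variants do."""
  # Default: shrink by 1/sqrt(vocab_size); the embed-init variant instead
  # supplies an explicit multiplier.
  scale = (1.0 / jnp.sqrt(vocab_size)) if init_multiplier is None else init_multiplier
  table = jnn.initializers.uniform(scale=1.0)(key, (vocab_size, embed_dim))
  return scale * table


key = jax.random.PRNGKey(0)
default_table = init_embedding_table(key, vocab_size=1024, embed_dim=8)
variant_table = init_embedding_table(key, vocab_size=1024, embed_dim=8, init_multiplier=1.0)
```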
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + if self.embedding_init_multiplier is None: + scale = 1 / jnp.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + def scaled_init(key, shape, dtype=jnp.float_): - return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) / - jnp.sqrt(self.vocab_size)) + return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale embedding_table = self.param('embedding_table', scaled_init, [self.vocab_size, self.embed_dim]) @@ -86,6 +189,8 @@ def scaled_init(key, shape, dtype=jnp.float_): embed_features = embedding_table[idx_lookup] embed_features = jnp.reshape(embed_features, [batch_size, -1, self.embed_dim]) + if self.use_layer_norm: + embed_features = nn.LayerNorm()(embed_features) feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) dot_interact_output = dot_interact(concat_features=feature_stack) top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], @@ -103,6 +208,8 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input) if layer_idx < (num_layers_top - 1): top_mlp_input = nn.relu(top_mlp_input) + if self.use_layer_norm: + top_mlp_input = nn.LayerNorm()(top_mlp_input) if (self.dropout_rate is not None and self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_input = nn.Dropout( diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index a76a70289..795a6b30e 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -73,23 +73,31 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + tabulate: Optional[bool] = False, + ) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate - self._model = models.DlrmSmall( + if self.use_resnet: + model_class = models.DLRMResNet + else: + model_class = models.DlrmSmall + self._model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate) + dropout_rate=dropout_rate, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) params_rng, dropout_rng = jax.random.split(rng) init_fake_batch_size = 2 num_categorical_features = 26 - input_size = self.num_dense_features + num_categorical_features + num_dense_features = 13 + input_size = num_dense_features + num_categorical_features input_shape = (init_fake_batch_size, input_size) - init_fn = functools.partial(self._model.init, train=False) initial_variables = jax.jit(init_fn)( {'params': params_rng, 'dropout': dropout_rng}, @@ -154,3 +162,53 @@ def _eval_batch(self, class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): vocab_size: int = 32 * 128 * 16 + + +class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.123744 + + @property + def test_target_value(self) -> float: + return 0.126152 + + +class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): + mlp_bottom_dims: Tuple[int, int] = (256, 
256, 256) + mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) + + @property + def use_resnet(self) -> bool: + """Whether or not to use residual connections in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.124027 + + @property + def test_target_value(self) -> float: + return 0.126468 + + +class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def validation_target_value(self) -> float: + return 0.124286 + + @property + def test_target_value(self) -> float: + # Todo + return 0.126725 + + @property + def embedding_init_multiplier(self) -> float: + return 1.0 diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py index de6b4d1dd..4110b30fc 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py @@ -6,6 +6,21 @@ from torch import nn +class DenseBlock(nn.Module): + """Dense block with optional residual connection.""" "" + + def __init__(self, module, resnet=False): + super().__init__() + self.module = module + self.resnet = resnet + + def forward(self, x): + if self.resnet: + return self.module(x) + x + else: + return self.module(x) + + class DotInteract(nn.Module): """Performs feature interaction operation between dense or sparse features.""" @@ -25,6 +40,127 @@ def forward(self, dense_features, sparse_features): return torch.cat((dense_features, interactions_flat), dim=1) +class DLRMResNet(nn.Module): + """Define a DLRM-Small model. + + Parameters: + vocab_size: vocab size of embedding table. + num_dense_features: number of dense features as the bottom mlp input. + mlp_bottom_dims: dimensions of dense layers of the bottom mlp. + mlp_top_dims: dimensions of dense layers of the top mlp. + embed_dim: embedding dimension. + """ + + def __init__(self, + vocab_size, + num_dense_features=13, + num_sparse_features=26, + mlp_bottom_dims=(256, 256, 256), + mlp_top_dims=(256, 256, 256, 256, 1), + embed_dim=128, + dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): + super().__init__() + self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) + self.num_dense_features = num_dense_features + self.num_sparse_features = num_sparse_features + self.mlp_bottom_dims = mlp_bottom_dims + self.mlp_top_dims = mlp_top_dims + self.embed_dim = embed_dim + + # Ideally, we should use the pooled embedding implementation from + # `TorchRec`. However, in order to have identical implementation + # with that of Jax, we define a single embedding matrix. 
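Before the chunked table code that follows: rather than one huge `nn.Parameter`, the PyTorch Criteo models split the unified embedding table into equally sized chunks that are registered individually and concatenated again at lookup time (the diff spells the count `num_chucks`). A toy sketch of the same idea, with a hypothetical class name and small sizes:

```python
import math

import torch
from torch import nn


class ChunkedEmbeddingSketch(nn.Module):
  """Toy version of the chunked embedding table used in the Criteo models."""

  def __init__(self, vocab_size=1024, embed_dim=8, num_chunks=4):
    super().__init__()
    assert vocab_size % num_chunks == 0
    self.vocab_size = vocab_size
    scale = 1.0 / math.sqrt(vocab_size)
    self.chunks = nn.ParameterList([
        nn.Parameter(scale * torch.rand(vocab_size // num_chunks, embed_dim))
        for _ in range(num_chunks)
    ])

  def forward(self, ids):
    # Rebuild the full table, then index it; ids are hashed into the vocab.
    table = torch.cat(list(self.chunks), dim=0)
    return table[ids % self.vocab_size]


emb = ChunkedEmbeddingSketch()
out = emb(torch.randint(0, 10_000, (2, 26)))  # shape (2, 26, 8)
```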
+ num_chucks = 4 + assert vocab_size % num_chucks == 0 + self.embedding_table_chucks = [] + scale = 1.0 / torch.sqrt(self.vocab_size) + for i in range(num_chucks): + chunk = nn.Parameter( + torch.Tensor(self.vocab_size // num_chucks, self.embed_dim)) + chunk.data.uniform_(0, 1) + chunk.data = scale * chunk.data + self.register_parameter(f'embedding_chunk_{i}', chunk) + self.embedding_table_chucks.append(chunk) + + input_dim = self.num_dense_features + bot_mlp_blocks = [] + for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims): + block = [] + block.append(nn.Linear(input_dim, dense_dim)) + block.append(nn.ReLU(inplace=True)) + block = nn.Sequential(*block) + if layer_idx > 0: + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + bot_mlp_blocks.append(block) + input_dim = dense_dim + self.bot_mlp = nn.Sequential(*bot_mlp_blocks) + + for module in self.bot_mlp.modules(): + if isinstance(module, nn.Linear): + limit = math.sqrt(6. / (module.in_features + module.out_features)) + nn.init.uniform_(module.weight.data, -limit, limit) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + # Number of sparse features = 26 + fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1] + num_layers_top = len(self.mlp_top_dims) + mlp_top_blocks = [] + for layer_idx, fan_out in enumerate(self.mlp_top_dims): + block = [] + block.append(nn.Linear(fan_in, fan_out)) + if layer_idx < (num_layers_top - 1): + block.append(nn.ReLU(inplace=True)) + if (dropout_rate is not None and dropout_rate > 0.0 and + layer_idx == num_layers_top - 2): + block.append(nn.Dropout(p=dropout_rate)) + block = nn.Sequential(*block) + if (layer_idx != 0) and (layer_idx != num_layers_top - 1): + block = DenseBlock(block, resnet=True) + else: + block = DenseBlock(block) + mlp_top_blocks.append(block) + fan_in = fan_out + self.top_mlp = nn.Sequential(*mlp_top_blocks) + + for module in self.top_mlp.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_( + module.weight.data, + 0., + math.sqrt(2. / (module.in_features + module.out_features))) + nn.init.normal_(module.bias.data, + 0., + math.sqrt(1. / module.out_features)) + + def forward(self, x): + batch_size = x.shape[0] + + dense_features, sparse_features = torch.split( + x, [self.num_dense_features, self.num_sparse_features], 1) + + # Bottom MLP. + embedded_dense = self.bot_mlp(dense_features) + + # Sparse feature processing. + sparse_features = sparse_features.to(dtype=torch.int32) + idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size + embedding_table = torch.cat(self.embedding_table_chucks, dim=0) + embedded_sparse = embedding_table[idx_lookup] + embedded_sparse = torch.reshape(embedded_sparse, + [batch_size, 26 * self.embed_dim]) + top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1) + + # Final MLP. + logits = self.top_mlp(top_mlp_input) + return logits + + class DlrmSmall(nn.Module): """Define a DLRM-Small model. 
@@ -43,7 +179,9 @@ def __init__(self, mlp_bottom_dims=(512, 256, 128), mlp_top_dims=(1024, 1024, 512, 256, 1), embed_dim=128, - dropout_rate=0.0): + dropout_rate=0.0, + use_layer_norm=False, + embedding_init_multiplier=None): super().__init__() self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32) self.num_dense_features = num_dense_features @@ -51,6 +189,7 @@ def __init__(self, self.mlp_bottom_dims = mlp_bottom_dims self.mlp_top_dims = mlp_top_dims self.embed_dim = embed_dim + self.embedding_init_multiplier = embedding_init_multiplier # Ideally, we should use the pooled embedding implementation from # `TorchRec`. However, in order to have identical implementation @@ -58,7 +197,12 @@ def __init__(self, num_chucks = 4 assert vocab_size % num_chucks == 0 self.embedding_table_chucks = [] - scale = 1.0 / torch.sqrt(self.vocab_size) + + if self.embedding_init_multiplier is None: + scale = 1.0 / torch.sqrt(self.vocab_size) + else: + scale = self.embedding_init_multiplier + for i in range(num_chucks): chunk = nn.Parameter( torch.Tensor(self.vocab_size // num_chucks, self.embed_dim)) @@ -67,11 +211,13 @@ def __init__(self, self.register_parameter(f'embedding_chunk_{i}', chunk) self.embedding_table_chucks.append(chunk) - bottom_mlp_layers = [] input_dim = self.num_dense_features + bottom_mlp_layers = [] for dense_dim in self.mlp_bottom_dims: bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim)) bottom_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6)) input_dim = dense_dim self.bot_mlp = nn.Sequential(*bottom_mlp_layers) for module in self.bot_mlp.modules(): @@ -86,18 +232,24 @@ def __init__(self, # TODO: Write down the formula here instead of the constant. input_dims = 506 - top_mlp_layers = [] num_layers_top = len(self.mlp_top_dims) + top_mlp_layers = [] for layer_idx, fan_out in enumerate(self.mlp_top_dims): fan_in = input_dims if layer_idx == 0 \ else self.mlp_top_dims[layer_idx - 1] top_mlp_layers.append(nn.Linear(fan_in, fan_out)) if layer_idx < (num_layers_top - 1): top_mlp_layers.append(nn.ReLU(inplace=True)) + if use_layer_norm: + top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6)) if (dropout_rate is not None and dropout_rate > 0.0 and layer_idx == num_layers_top - 2): top_mlp_layers.append(nn.Dropout(p=dropout_rate)) self.top_mlp = nn.Sequential(*top_mlp_layers) + if use_layer_norm: + self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6) + else: + self.embed_ln = None for module in self.top_mlp.modules(): if isinstance(module, nn.Linear): nn.init.normal_( @@ -124,7 +276,8 @@ def forward(self, x): embedded_sparse = embedding_table[idx_lookup] embedded_sparse = torch.reshape(embedded_sparse, [batch_size, -1, self.embed_dim]) - + if self.embed_ln: + embedded_sparse = self.embed_ln(embedded_sparse) # Dot product interactions. 
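For reference, the `use_layer_norm` variant in the PyTorch `DlrmSmall` above simply appends an `nn.LayerNorm` after each ReLU while the `Sequential` MLPs are being built. A small sketch of that builder pattern, with a hypothetical helper name and toy dimensions:

```python
import torch
from torch import nn


def build_mlp(input_dim, dims, use_layer_norm=False):
  """Sketch of the bottom-MLP builder with the optional LayerNorm variant."""
  layers = []
  for dense_dim in dims:
    layers.append(nn.Linear(input_dim, dense_dim))
    layers.append(nn.ReLU(inplace=True))
    if use_layer_norm:
      # The variant normalizes each hidden activation; eps matches the diff.
      layers.append(nn.LayerNorm(dense_dim, eps=1e-6))
    input_dim = dense_dim
  return nn.Sequential(*layers)


mlp = build_mlp(13, (512, 256, 128), use_layer_norm=True)
out = mlp(torch.randn(2, 13))  # shape (2, 128)
```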
concatenated_dense = self.dot_interact( dense_features=embedded_dense, sparse_features=embedded_sparse) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 55b68fb2f..85bb602d1 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -10,8 +10,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.models import \ - DlrmSmall +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch import models from algorithmic_efficiency.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -76,13 +75,19 @@ def init_model_fn( torch.random.manual_seed(rng[0]) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False - model = DlrmSmall( + if self.use_resnet: + model_class = models.DLRMResNet + else: + model_class = models.DlrmSmall + model = model_class( vocab_size=self.vocab_size, num_dense_features=self.num_dense_features, mlp_bottom_dims=self.mlp_bottom_dims, mlp_top_dims=self.mlp_top_dims, embed_dim=self.embed_dim, - dropout_rate=dropout_rate) + dropout_rate=dropout_rate, + use_layer_norm=self.use_layer_norm, + embedding_init_multiplier=self.embedding_init_multiplier) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -238,3 +243,53 @@ def _eval_batch(self, class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): vocab_size: int = 32 * 128 * 16 + + +class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.123744 + + @property + def test_target_value(self) -> float: + return 0.126152 + + +class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload): + mlp_bottom_dims: Tuple[int, int] = (256, 256, 256) + mlp_top_dims: Tuple[int, int, int] = (256, 256, 256, 256, 1) + + @property + def use_resnet(self) -> bool: + """Whether or not to use residual connections in the model.""" + return True + + @property + def validation_target_value(self) -> float: + return 0.124027 + + @property + def test_target_value(self) -> float: + return 0.126468 + + +class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): + + @property + def validation_target_value(self) -> float: + return 0.124286 + + @property + def test_target_value(self) -> float: + # Todo + return 0.126725 + + @property + def embedding_init_multiplier(self) -> float: + return 1.0 diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 13bd308fb..17661e670 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -33,6 +33,20 @@ def has_reached_validation_target(self, eval_result: Dict[str, float]) -> bool: return eval_result['validation/loss'] < self.validation_target_value + @property + def use_layer_norm(self) -> bool: + """Whether or not to use LayerNorm in the model.""" + return False + + @property + def use_resnet(self) -> bool: + """Whether 
or not to use residual connections in the model.""" + return False + + @property + def embedding_init_multiplier(self) -> float: + return None + @property def validation_target_value(self) -> float: return 0.123735 diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index b7f33b94b..e638df078 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -80,9 +80,7 @@ def _build_input_queue( weights = torch.as_tensor( batch['weights'], dtype=torch.bool, device=DEVICE) else: - weights = torch.ones((batch['targets'].shape[-1],), - dtype=torch.bool, - device=DEVICE) + weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE) # Send batch to other devices when using DDP. if USE_PYTORCH_DDP: dist.broadcast(inputs, src=0) diff --git a/algorithmic_efficiency/workloads/utils.py b/algorithmic_efficiency/workloads/utils.py new file mode 100644 index 000000000..7719f91fb --- /dev/null +++ b/algorithmic_efficiency/workloads/utils.py @@ -0,0 +1,14 @@ +import flax.linen as nn +import jax + + +def print_jax_model_summary(model, fake_inputs): + """Prints a summary of the jax module.""" + tabulate_fn = nn.tabulate( + model, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, 'force_jupyter': False, 'width': 240 + }, + ) + print(tabulate_fn(fake_inputs, train=False)) diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 78da438fe..617316f4a 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -24,6 +24,10 @@ 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload' }, + 'criteo1tb_embed_init': { + 'workload_path': 'criteo1tb/criteo1tb', + 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload' + }, 'criteo1tb_resnet': { 'workload_path': 'criteo1tb/criteo1tb', 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload' @@ -69,14 +76,23 @@ 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'}, } -BASE_WORKLOADS = ['criteo1tb', 'fastmri', ' imagenet_resnet', 'imagenet_vit', - 'librispeech_conformer', 'librispeech_deepspeech', - 'ogbg', 'wmt'] +BASE_WORKLOADS = [ + 'criteo1tb', + 'fastmri', + 'imagenet_resnet', + 'imagenet_vit', + 'librispeech_conformer', + 'librispeech_deepspeech', + 'ogbg', + 'wmt' +] + def get_base_workload_name(workload_name): for base_workload_name in BASE_WORKLOADS: if base_workload_name in workload_name: - return base_workload_name + return base_workload_name + def convert_filepath_to_module(path: str): base, extension = os.path.splitext(path) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 2f808b64b..f9ee2f138 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -388,7 +388,7 @@ def download_fastmri(data_dir, def extract(source, dest, mode='r:xz'): if not os.path.exists(dest): - os.path.makedirs(dest) + os.makedirs(dest) logging.info(f'Extracting {source} to {dest}') tar = tarfile.open(source, mode) logging.info('Opened tar') diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 644c4b368..c2f79d6b5 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -114,7 +114,7 @@ VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt"
"mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_vit" "fastmri" "ogbg" \ "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \ - "criteo1tb_resnet" "criteo1tb_layernorm" \ + "criteo1tb_resnet" "criteo1tb_layernorm" "criteo_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ "conformer_gelu") diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md new file mode 100644 index 000000000..614f87b32 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/README.md @@ -0,0 +1,83 @@ +# Prize Qualification Baselines +This directory contains the baseine(s) that submissions that must beat to qualify for prizes. + +TODO: link back to section in rules. + +## Externally Tuned Ruleset + +### JAX + +The prize qualification baseline submissions for jax are: +- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` + +Example command: + +```bash +python3 submission_runner.py \ + --framework=jax \ + --data_dir= \ + --experiment_dir= \ + --experiment_name= \ + --workload= \ + --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json +``` + +### PyTorch + +The prize qualification baseline submissionss for PyTorch are: +- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` + + +Example command: + +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ + --framework=pytorch \ + --data_dir= \ + --experiment_dir= \ + --experiment_name=t \ + --workload=\ + --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json +``` + +## Self-tuning Ruleset + +### JAX + +The prize qualification baseline submissionss for jax are: +- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py` + +Example command: +```bash +python3 submission_runner.py \ + --framework=jax \ + --data_dir= \ + --experiment_dir= \ + --experiment_name= \ + --workload= \ + --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_ruleset=self +``` + +### PyTorch + +The prize qualification baseline submissionss for PyTorch are: +- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- `feference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py` + +Example command: +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ + --framework=pytorch \ + --data_dir= \ + --experiment_dir= \ + --experiment_name=t \ + --workload=\ + --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + 
--tuning_ruleset=self +``` \ No newline at end of file diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
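Right after the cross-device `psum` below, gradients are clipped by their global L2 norm: the scaling factor `grad_clip / (grad_norm + eps)` is clamped to at most 1, so small gradients pass through unchanged. A standalone sketch of that step outside `pmap`, with a hypothetical function name and a toy gradient tree:

```python
import jax
import jax.numpy as jnp

_GRAD_CLIP_EPS = 1e-6


def clip_by_global_norm(grad, grad_clip):
  """Sketch of the clipping applied after the cross-device psum above."""
  grad_norm = jnp.sqrt(
      sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
  scale = jax.lax.clamp(
      min=0.0, x=grad_clip / (grad_norm + _GRAD_CLIP_EPS), max=1.0)
  return jax.tree_map(lambda g: g * scale, grad), grad_norm


grads = {'w': jnp.ones((3, 3)) * 10.0, 'b': jnp.ones((3,))}
clipped, norm = clip_by_global_norm(grads, grad_clip=1.0)
```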
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..ef0c11c0d --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
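One detail worth calling out in the target-setting variant above: the warmup-plus-cosine schedule is built over only `0.75 * workload.step_hint`, so the learning rate reaches its floor well before the full budget. A standalone sketch of the schedule construction, with a hypothetical helper name and toy step counts and hyperparameter values:

```python
import optax


def warmup_cosine_schedule(step_hint, warmup_factor, peak_lr):
  """Sketch of the linear-warmup + cosine-decay schedule built above."""
  warmup_steps = int(warmup_factor * step_hint)
  warmup = optax.linear_schedule(
      init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
  cosine = optax.cosine_decay_schedule(
      init_value=peak_lr, decay_steps=max(step_hint - warmup_steps, 1))
  return optax.join_schedules(
      schedules=[warmup, cosine], boundaries=[warmup_steps])


# Target-setting flavor: the schedule only spans 75% of the step hint.
schedule = warmup_cosine_schedule(
    step_hint=int(0.75 * 10_000), warmup_factor=0.02, peak_lr=1e-3)
print(schedule(0), schedule(150), schedule(7_499))
```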
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..01cffc52e --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
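+  # Metrics are logged on every step for the first 100 steps and every 500
+  # steps afterwards; grad_norm below is the global L2 norm over all
+  # parameters that received gradients.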
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..530dd3acf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 
'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..c54202e56 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
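+
+  Example (illustrative only; `params` and `grads` stand for arbitrary
+  matching pytrees):
+    tx = optax.chain(scale_by_nadam(b1=0.9, b2=0.999),
+                     scale_by_learning_rate(1e-3))
+    state = tx.init(params)
+    updates, state = tx.update(grads, state)
+    params = optax.apply_updates(params, updates)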
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..dd42743e2 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..57da48167 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing 
if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..ef6e84c94 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submission_runner.py b/submission_runner.py index 37d89d136..dcf228595 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -489,7 +489,7 @@ def score_submission_on_workload(workload: spec.Workload, data_selection = submission_module.data_selection try: global_batch_size = submission_module.get_batch_size(workload_name) - except: + except ValueError: base_workload_name = workloads.get_base_workload_name(workload_name) global_batch_size = submission_module.get_batch_size(base_workload_name) # n_gpus has to be set here, because we cannot call the first Jax operation diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 83297a7d9..3dd01cae0 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,6 +4,9 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#disallowed-submissions for guidelines. """ +from typing import Dict, Iterator, List, Tuple + +from algorithmic_efficiency import spec def init_optimizer_state(workload: spec.Workload, @@ -41,19 +44,15 @@ def update_params(workload: spec.Workload, def get_batch_size(workload_name): """ - Returns batch size for each workload. 
- Valid workload_name values are in - ["wmt", - "ogbg", - "criteo1tb", - "fastmri", - "imagenet_resnet", - "imagenet_vit", - "librispeech_deepspeech", - "librispeech_conformer"] + Gets batch size for workload. + Args: + workload_name (str): Valid workload_name values are: "wmt", "ogbg", + "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit", + "librispeech_deepspeech", "librispeech_conformer". Returns: - batch_size - + int: batch_size + Raises: + ValueError: If workload_name is not handled. """ pass diff --git a/tests/modeldiffs/criteo1tb_embed_init/__init__.py b/tests/modeldiffs/criteo1tb_embed_init/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py new file mode 100644 index 000000000..719484037 --- /dev/null +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -0,0 +1,83 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ + Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ + Criteo1TbDlrmSmallEmbedInitWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + s_count = None + for i in k: + if 'Sequential' in i: + s_count = int(i.split('_')[1]) + continue + if 'Embedding' in i: + return ('embedding_table',) + if 'Linear' in i: + i = i.replace('Linear', 'Dense') + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) + elif 'weight' in i: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = { + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), + } + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/criteo1tb_layernorm/__init__.py b/tests/modeldiffs/criteo1tb_layernorm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py new file mode 100644 index 000000000..3fc2a750a --- /dev/null +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -0,0 +1,95 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ + Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ + Criteo1TbDlrmSmallLayerNormWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + s_count = None + layer_norm = False + for i in k: + if 'Sequential' in i: + s_count = int(i.split('_')[1]) + continue + if 'Embedding' in i: + return ('embedding_table',) + if 'Linear' in i: + i = i.replace('Linear', 'Dense') + name, count = i.split('_') + i = name + '_' + str(s_count * 3 + int(count)) + if 'LayerNorm' in i: + layer_norm = True + name, count = i.split('_') + # There is a LayerNorm on the embedding between the bottom and top MLPs. + if s_count is not None: + i = name + '_' + str(s_count * 4 + int(count)) + else: + i = name + '_' + str(3) + elif 'weight' in i: + if layer_norm: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = { + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), + } + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + + # Test outputs for identical weights and inputs. + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/criteo1tb_resnet/__init__.py b/tests/modeldiffs/criteo1tb_resnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py new file mode 100644 index 000000000..b9dbbc80e --- /dev/null +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -0,0 +1,102 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import jax.numpy as jnp +import numpy as np +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ + Criteo1TbDlrmSmallResNetWorkload as JaxWorkload +from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ + Criteo1TbDlrmSmallResNetWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff + + +def key_transform(k): + new_key = [] + mlp_count = None + resnet_block_count = None + mlp_block_count = None + for i in k: + if 'Embedding' in i: + return ('embedding_table',) + if 'Sequential' in i: + if mlp_count is None: + mlp_count = int(i.split('_')[1]) + else: + mlp_block_count = int(i.split('_')[1]) + continue + if 'DenseBlock' in i: + # Offset the resnet block count by 1, + # since the first mlp layer has no resnet connection. + resnet_block_count = int(i.split('_')[1]) + continue + if 'Linear' in i: + i = i.replace('Linear', 'Dense') + name, _ = i.split('_') + block_count = mlp_block_count if mlp_block_count else resnet_block_count + i = name + '_' + str(mlp_count * 3 + block_count) + elif 'weight' in i: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +def sd_transform(sd): + out = {} + chunks = [] + for k in sd: + if 'embedding_chunk' in ''.join(k): + chunks.append(sd[k].cpu()) + else: + out[k] = sd[k] + out[('embedding_table',)] = torch.cat(chunks, dim=0) + return out + + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + pyt_batch = { + 'inputs': torch.ones((2, 13 + 26)), + 'targets': torch.randint(low=0, high=1, size=(2,)), + 'weights': torch.ones(2), + } + + init_fake_batch_size = 2 + num_categorical_features = 26 + input_size = 13 + num_categorical_features + input_shape = (init_fake_batch_size, input_size) + fake_inputs = jnp.ones(input_shape, jnp.float32) + jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch['inputs'] = fake_inputs + + # Test outputs for identical weights and inputs. 
+ pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=None) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index bc53de875..f96fa672b 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -56,5 +56,8 @@ def out_diff(jax_workload, out_p = out_transform(out_p) out_j = out_transform(out_j) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + max_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).max() + min_diff = np.abs(out_p.detach().numpy() - np.array(out_j)).min() + + print(f'Max fprop difference between jax and pytorch: {max_diff}') + print(f'Min fprop difference between jax and pytorch: {min_diff}') diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py index 1926ab0cc..d9264b400 100644 --- a/tests/modeldiffs/torch2jax_utils.py +++ b/tests/modeldiffs/torch2jax_utils.py @@ -105,6 +105,7 @@ def diff(self): if s_p == s_j: count += 1 else: + print('Shape mismatch between pytorch and jax for key:') print(k, s_p, s_j) print(f'Number of values with identical shapes: {count}')
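A minimal, self-contained sketch of the in-place Nesterov averaging used in the nadamw functional above: the first-moment update is applied a second time to form the Nesterov numerator, the parameter step uses that value, and the extra application is then undone instead of keeping a copy of the gradient. The tensors and beta1 value below are arbitrary illustrations, not values from the submission.

import torch

# Arbitrary example values (assumption: any beta1 in [0, 1) and any tensors work).
torch.manual_seed(0)
beta1 = 0.9
grad = torch.randn(5)
exp_avg = torch.randn(5)

# Reference: explicit Nesterov-style numerator computed without in-place tricks.
m_t = beta1 * exp_avg + (1 - beta1) * grad            # updated first moment
nesterov_numerator = beta1 * m_t + (1 - beta1) * grad  # value used in the parameter step

# In-place version from the submission: update, update again, use, then undo.
m = exp_avg.clone()
m.mul_(beta1).add_(grad, alpha=1 - beta1)   # m_t
m.mul_(beta1).add_(grad, alpha=1 - beta1)   # Nesterov numerator
assert torch.allclose(m, nesterov_numerator)
m.sub_(grad, alpha=1 - beta1).div_(beta1)   # undo the second application -> back to m_t
assert torch.allclose(m, m_t)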
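The pytorch_cosine_warmup helper above chains a linear warmup into a cosine decay via SequentialLR. A small sketch of the same composition on a toy model; the model, step_hint value, and base learning rate are illustrative assumptions, and torch.optim.AdamW stands in for the NAdamW class defined in the submission.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = torch.nn.Linear(4, 1)  # toy model (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # stand-in for NAdamW

step_hint = 1000  # illustrative; a real run would use workload.step_hint
warmup_steps = int(0.02 * step_hint)  # warmup_factor = 0.02 as in HPARAMS
warmup = LinearLR(
    optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
cosine_decay = CosineAnnealingLR(optimizer, T_max=max(step_hint - warmup_steps, 1))
scheduler = SequentialLR(
    optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])

for _ in range(step_hint):
  optimizer.step()   # a real update would backprop a loss first
  scheduler.step()
print(optimizer.param_groups[0]['lr'])  # LR has warmed up, then decayed along the cosine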
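The self-tuning submission above assigns the plain HPARAMS dict to hyperparameters and then reads fields with attribute access (e.g. hyperparameters.learning_rate). A hedged sketch, assuming attribute-style access is the intent, of a namedtuple wrapper that supports this while keeping hasattr(hyperparameters, 'grad_clip') False for absent fields; the wrapper is hypothetical and not part of the diff.

import collections

# Hypothetical wrapper (assumption, not in the diff): expose HPARAMS through
# attribute access to match uses such as `hyperparameters.learning_rate`.
HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 0.0017486387539278373,
    "one_minus_beta1": 0.06733926164,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS)(**HPARAMS)

assert HPARAMS.learning_rate > 0          # attribute access now works
assert not hasattr(HPARAMS, 'grad_clip')  # optional fields still read as absent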
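The key_transform helpers in the new compare.py files rename torch2jax-style PyTorch module paths into the corresponding Flax parameter names before shapes are compared. A short illustration using the criteo1tb_embed_init variant; the example key tuples are assumed torch2jax paths, shown only to make the mapping concrete.

def key_transform(k):
  # Same logic as tests/modeldiffs/criteo1tb_embed_init/compare.py above.
  new_key = []
  s_count = None
  for i in k:
    if 'Sequential' in i:
      s_count = int(i.split('_')[1])
      continue
    if 'Embedding' in i:
      return ('embedding_table',)
    if 'Linear' in i:
      i = i.replace('Linear', 'Dense')
      name, count = i.split('_')
      i = name + '_' + str(s_count * 3 + int(count))
    elif 'weight' in i:
      i = i.replace('weight', 'kernel')
    new_key.append(i)
  return tuple(new_key)


# Hypothetical torch2jax paths (assumptions for illustration):
print(key_transform(('Sequential_1', 'Linear_2', 'weight')))  # -> ('Dense_5', 'kernel')
print(key_transform(('Embedding_0', 'weight')))               # -> ('embedding_table',)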