diff --git a/CHANGELOG.md b/CHANGELOG.md index d28b30000..987d8d20d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Change Log +## algoperf-benchmark-0.1.4 (2024-03-26) + +Upgrade CUDA version to CUDA 12.1: +- Upgrade CUDA version in Dockerfiles that will be used for scoring. +- Update Jax and PyTorch package version tags to use local CUDA installation. + +Add flag for completely disabling checkpointing. +- Note that we will run with checkpointing off at scoring time. + +Update Deepspeech and Conformer variant target setting configurations. +- Note that variant targets are not final. + +Fixed bug in scoring code to take best trial in a study for external-tuning ruleset. + +Added instructions for submission. + +Changed default number of workers for PyTorch data loaders to 0. Running with >0 may lead to incorrect eval results see https://github.com/mlcommons/algorithmic-efficiency/issues/732. + ## algoperf-benchmark-0.1.2 (2024-03-04) Workload variant additions and fixes: - Add Deepspeech workload variant diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 50ae4dfdb..006b972ec 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -388,7 +388,42 @@ python score_submissions.py --submission_directory We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179). -## Package Submission for Self-Reporting +## Package your Submission code + +If you have registered for the AlgoPerf competition you will receive +an email on 3/27/2024 with a link to a UI to upload a compressed submission folder. + +To package your submission modules please make sure your submission folder is structured as follows: + +```bash +submission_folder/ +├── external_tuning +│ ├── algorithm_name +│ │ ├── helper_module.py +│ │ ├── requirements.txt +│ │ ├── submission.py +│ │ └── tuning_search_space.json +│ └── other_algorithm_name +│ ├── requirements.txt +│ ├── submission.py +│ └── tuning_search_space.json +└── self_tuning + └── algorithm_name + ├── requirements.txt + └── submission.py +``` + +Specifically we require that: +1. There exist subdirectories in the the submission folder named after the ruleset: `external_tuning` or `self_tuning`. +2. The ruleset subdirectories contain directories named according to +some identifier of the algorithm. +3. Each algorithm subdirectory contains a `submission.py` module. Additional helper modules are allowed if prefer to you organize your code into multiple files. If there are additional python packages that have to be installed for the algorithm also include a `requirements.txt` with package names and versions in the algorithm subdirectory. +4. For `external_tuning` algorithms the algorithm subdirectory +should contain a `tuning_search_space.json`. + +To check that your submission folder meets the above requirements you can run the `submissions/repo_checker.py` script. + +## Package Logs for Self-Reporting Submissions To prepare your submission for self reporting run: ``` diff --git a/README.md b/README.md index 04052b4e9..3628caede 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,9 @@ > [!IMPORTANT] > Upcoming Deadline: -> Submission deadline: **April 04th, 2024** (*moved by a week*) \ -> For other key dates please see [Call for Submissions](/CALL_FOR_SUBMISSIONS.md). +> Submission deadline: **April 04th, 2024** (*moved by a week*). \ +> For submission instructions please see [Packaging your Submission Code](/GETTING_STARTED.md#package-your-submission-code) section in the Getting Started document.\ +> For other key dates please see [Call for Submissions](CALL_FOR_SUBMISSIONS.md). ## Table of Contents diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 68e9a9cfe..cf1ea6c32 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -26,11 +26,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed + 2**32 if seed < 0 else seed + return seed % 2**32 if isinstance(seed, list): - return [s + 2**32 if s < 0 else s for s in seed] + return [s % 2**32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) + return np.array([s % 2**32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index 84a0a7416..3743dc1ff 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -173,7 +173,7 @@ def use_layer_norm(self) -> bool: @property def validation_target_value(self) -> float: - return 0.123744 + return 0.123757 @property def test_target_value(self) -> float: @@ -191,23 +191,23 @@ def use_resnet(self) -> bool: @property def validation_target_value(self) -> float: - return 0.124027 + return 0.12415 @property def test_target_value(self) -> float: - return 0.126468 + return 0.12648 class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): @property def validation_target_value(self) -> float: - return 0.124286 + return 0.129657 @property def test_target_value(self) -> float: # Todo - return 0.126725 + return 0.131967 @property def embedding_init_multiplier(self) -> float: diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index c63ac3f7b..446267440 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -254,7 +254,7 @@ def use_layer_norm(self) -> bool: @property def validation_target_value(self) -> float: - return 0.123744 + return 0.123757 @property def test_target_value(self) -> float: @@ -272,23 +272,23 @@ def use_resnet(self) -> bool: @property def validation_target_value(self) -> float: - return 0.124027 + return 0.12415 @property def test_target_value(self) -> float: - return 0.126468 + return 0.12648 class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload): @property def validation_target_value(self) -> float: - return 0.124286 + return 0.129657 @property def test_target_value(self) -> float: # Todo - return 0.126725 + return 0.131967 @property def embedding_init_multiplier(self) -> float: diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index e4810e142..d8de214f5 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -272,11 +272,11 @@ def use_silu(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.22009 + return 0.75445 @property def test_target_value(self) -> float: - return 1 - 0.3426 + return 0.6323 class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @@ -287,11 +287,11 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.22077 + return 0.76765 @property def test_target_value(self) -> float: - return 1 - 0.3402 + return 0.6519 class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): @@ -302,8 +302,8 @@ def bn_init_scale(self) -> float: @property def validation_target_value(self) -> float: - return 1 - 0.23474 + return 0.76526 @property def test_target_value(self) -> float: - return 1 - 0.3577 + return 0.6423 diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 5c7c6c7d2..3549911fa 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -326,11 +326,11 @@ def use_silu(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.22009 + return 0.75445 @property def test_target_value(self) -> float: - return 1 - 0.342 + return 0.6323 class ImagenetResNetGELUWorkload(ImagenetResNetWorkload): @@ -341,11 +341,11 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.22077 + return 0.76765 @property def test_target_value(self) -> float: - return 1 - 0.3402 + return 0.6519 class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload): @@ -356,8 +356,8 @@ def bn_init_scale(self) -> float: @property def validation_target_value(self) -> float: - return 1 - 0.23474 + return 0.76526 @property def test_target_value(self) -> float: - return 1 - 0.3577 + return 0.6423 diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index a54ee9b5e..2ad71ffd0 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -99,11 +99,11 @@ def use_glu(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.2233 + return 0.75738 @property def test_target_value(self) -> float: - return 1 - 0.3455 + return 0.6359 class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @@ -114,11 +114,11 @@ def use_post_layer_norm(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.24688 + return 0.75312 @property def test_target_value(self) -> float: - return 1 - 0.3714 + return 0.6286 class ImagenetVitMapWorkload(ImagenetVitWorkload): @@ -129,8 +129,8 @@ def use_map(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.22886 + return 0.77113 @property def test_target_value(self) -> float: - return 1 - 0.3477 + return 0.6523 diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 51c79b2d0..703d40b07 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -90,11 +90,11 @@ def use_glu(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.2233 + return 0.75738 @property def test_target_value(self) -> float: - return 1 - 0.3455 + return 0.6359 class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @@ -105,11 +105,11 @@ def use_post_layer_norm(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.24688 + return 0.75312 @property def test_target_value(self) -> float: - return 1 - 0.3714 + return 0.6286 class ImagenetVitMapWorkload(ImagenetVitWorkload): @@ -120,8 +120,8 @@ def use_map(self) -> bool: @property def validation_target_value(self) -> float: - return 1 - 0.22886 + return 0.77113 @property def test_target_value(self) -> float: - return 1 - 0.3477 + return 0.6523 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 805877e31..f4d1ab0f3 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -161,7 +161,7 @@ def _build_input_queue( batch_size=global_batch_size, shuffle=train, sampler=None, - num_workers=4 if train else self.eval_num_workers, + num_workers=4, prefetch_factor=10, pin_memory=False, drop_last=train, @@ -388,11 +388,11 @@ def attention_temperature(self) -> float: @property def validation_target_value(self) -> float: - return 0.082665 + return 0.109977 @property def test_target_value(self) -> float: - return 0.50168 + return 0.068065 class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @@ -403,11 +403,11 @@ def use_post_layer_norm(self) -> bool: @property def validation_target_value(self) -> float: - return 0.085371 + return 0.09731 @property def test_target_value(self) -> float: - return 0.053096 + return 0.05996 class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): @@ -418,8 +418,8 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: - return 0.077958 + return 0.094114 @property def test_target_value(self) -> float: - return 0.047643 + return 0.056629 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index fe3a1e179..502cb093e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -93,7 +93,7 @@ def __init__(self, out_features=self.encoder_dim, bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) - self.dropout = nn.Dropout(p=self.input_dropout_rate) + self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True) def forward(self, inputs, input_paddings): output_paddings = input_paddings @@ -195,7 +195,7 @@ def __init__(self, config: ConformerConfig): in_features=config.encoder_dim, out_features=config.encoder_dim * config.feed_forward_expansion_factor, bias=True) - self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate) + self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True) self.linear2 = nn.Linear( in_features=config.encoder_dim * config.feed_forward_expansion_factor, out_features=config.encoder_dim, @@ -206,7 +206,8 @@ def __init__(self, config: ConformerConfig): else: feed_forward_residual_dropout_rate = ( config.feed_forward_residual_dropout_rate) - self.dropout2 = nn.Dropout(p=feed_forward_residual_dropout_rate) + self.dropout2 = nn.Dropout( + p=feed_forward_residual_dropout_rate, inplace=True) def forward(self, inputs, padding_mask): inputs = self.ln(inputs) @@ -316,7 +317,7 @@ def __init__(self, config: ConformerConfig): attention_residual_dropout_rate = 0.1 else: attention_residual_dropout_rate = config.attention_residual_dropout_rate - self.dropout = nn.Dropout(p=attention_residual_dropout_rate) + self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True) def forward(self, outputs, paddings): outputs = self.ln(outputs) @@ -407,7 +408,7 @@ def __init__(self, config): conv_residual_dropout_rate = 0.0 else: conv_residual_dropout_rate = config.conv_residual_dropout_rate - self.dropout = nn.Dropout(p=conv_residual_dropout_rate) + self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True) def forward(self, inputs, input_paddings): inputs = self.ln(inputs) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 9b8e2d61c..155b30920 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -187,7 +187,7 @@ def _build_input_queue( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=4, pin_memory=True, drop_last=is_train) @@ -354,11 +354,11 @@ def attention_temperature(self) -> float: @property def validation_target_value(self) -> float: - return 0.082665 + return 0.109977 @property def test_target_value(self) -> float: - return 0.050168 + return 0.068065 class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload): @@ -369,11 +369,11 @@ def use_post_layer_norm(self) -> bool: @property def validation_target_value(self) -> float: - return 0.085371 + return 0.09731 @property def test_target_value(self) -> float: - return 0.053096 + return 0.05996 class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload): @@ -384,8 +384,8 @@ def use_gelu(self) -> bool: @property def validation_target_value(self) -> float: - return 0.077958 + return 0.094114 @property def test_target_value(self) -> float: - return 0.047643 + return 0.056629 diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4489c0402..8473fac0f 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -109,11 +109,11 @@ def use_tanh(self) -> bool: @property def validation_target_value(self) -> float: - return 0.133449 + return 0.150883 @property def test_target_value(self) -> float: - return 0.079810 + return 0.098613 class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): @@ -124,11 +124,11 @@ def enable_residual_connections(self) -> bool: @property def validation_target_value(self) -> float: - return 0.105042 + return 0.131564 @property def test_target_value(self) -> float: - return 0.060388 + return 0.079297 class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload @@ -156,8 +156,8 @@ def time_mask_count(self) -> int: @property def validation_target_value(self) -> float: - return 0.131553 + return 0.14342 @property def test_target_value(self) -> float: - return 0.082442 + return 0.090976 diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 23d533aa1..626bac278 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -114,6 +114,14 @@ class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload): def use_tanh(self) -> bool: return True + @property + def validation_target_value(self) -> float: + return 0.150883 + + @property + def test_target_value(self) -> float: + return 0.098613 + class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): @@ -121,6 +129,14 @@ class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload): def enable_residual_connections(self) -> bool: return False + @property + def validation_target_value(self) -> float: + return 0.131564 + + @property + def test_target_value(self) -> float: + return 0.079297 + class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload ): @@ -144,3 +160,11 @@ def freq_mask_count(self) -> int: @property def time_mask_count(self) -> int: return 15 + + @property + def validation_target_value(self) -> float: + return 0.14342 + + @property + def test_target_value(self) -> float: + return 0.090976 diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index b10d4056d..c69965692 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -299,7 +299,7 @@ class WmtWorkloadPostLN(WmtWorkload): @property def validation_target_value(self) -> float: - return 30.2003 + return 30.0779 @property def test_target_value(self) -> float: @@ -315,11 +315,11 @@ class WmtWorkloadAttentionTemp(WmtWorkload): @property def validation_target_value(self) -> float: - return 30.0756 + return 29.8611 @property def test_target_value(self) -> float: - return 29.8094 + return 29.4143 @property def attention_temp(self) -> float: @@ -331,11 +331,11 @@ class WmtWorkloadGLUTanH(WmtWorkload): @property def validation_target_value(self) -> float: - return 30.0002 + return 29.6517 @property def test_target_value(self) -> float: - return 29.8139 + return 29.0515 @property def activation(self) -> str: diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 9f6d817f4..5ef09d278 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -355,7 +355,7 @@ class WmtWorkloadPostLN(WmtWorkload): @property def validation_target_value(self) -> float: - return 30.2003 + return 30.0779 @property def test_target_value(self) -> float: @@ -371,11 +371,11 @@ class WmtWorkloadAttentionTemp(WmtWorkload): @property def validation_target_value(self) -> float: - return 30.0756 + return 29.8611 @property def test_target_value(self) -> float: - return 29.8094 + return 229.4143 @property def attention_temp(self) -> float: @@ -387,11 +387,11 @@ class WmtWorkloadGLUTanH(WmtWorkload): @property def validation_target_value(self) -> float: - return 30.0002 + return 29.6517 @property def test_target_value(self) -> float: - return 29.8139 + return 29.0515 @property def activation(self) -> str: diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json index 13bf07b4b..22f3376b4 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.0007852999990476642 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.6994142393023162 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.9918636824608852 ] }, "warmup_steps": { "feasible_points": [ - 9999 + 6000 ] }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.07286322158086678 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json index 13bf07b4b..ad200c01b 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json @@ -1,17 +1,17 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.000590120167916659 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.737199286155609 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.05919391544031072 ] }, "warmup_steps": { @@ -21,7 +21,7 @@ }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.14128519778326312 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json index 13bf07b4b..8297cf0ae 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.0014446807792420305 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.7427148812902895 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.8993064520764248 ] }, "warmup_steps": { "feasible_points": [ - 9999 + 3000 ] }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.06875136511682291 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json index b31b711f7..e76a48325 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.0035278622506232458 + 0.0020162740358935045 ] }, "beta1": { "feasible_points": [ - 0.8192305396005781 + 0.9604907112078142 ] }, "beta2": { "feasible_points": [ - 0.495850879212151 + 0.8765457000160508 ] }, "warmup_steps": { "feasible_points": [ - 6000 + 3600 ] }, "weight_decay": { "feasible_points": [ - 0.04339748256184769 + 0.0006149579248633481 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json index e20a2dae1..55f70f9fc 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.0014446807792420305 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.7427148812902895 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.8993064520764248 ] }, "warmup_steps": { "feasible_points": [ - 6000 + 1800 ] }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.06875136511682291 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json index e0121cc26..e5f906688 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.003632312571224348 + 0.003604759885558324 ] }, "beta1": { "feasible_points": [ - 0.9980088784197237 + 0.9931094324430452 ] }, "beta2": { "feasible_points": [ - 0.9982275351621527 + 0.9976871843749077 ] }, "warmup_steps": { "feasible_points": [ - 6000 + 720 ] }, "weight_decay": { "feasible_points": [ - 0.2479797019098727 + 0.120077307855989 ] } } diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 8009dbc88..8ee271804 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -157,7 +157,7 @@ def get_workloads_time_to_target(submission, # For each workload get submission time get the submission times to target. for workload, group in submission.groupby('workload'): - validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) + validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload) # Check number of studies time_vals_per_study = [] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 48777c69e..0b768855e 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -48,8 +48,9 @@ FLAGS = flags.FLAGS -def get_summary_df(workload, workload_df): - validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) +def get_summary_df(workload, workload_df, include_test_split=False): + validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation') + is_minimized = performance_profile.check_if_minimized(validation_metric) target_op = operator.le if is_minimized else operator.ge best_op = min if is_minimized else max @@ -58,32 +59,58 @@ def get_summary_df(workload, workload_df): summary_df = pd.DataFrame() summary_df['workload'] = workload_df['workload'] summary_df['trial'] = workload_df['trial'].apply(lambda x: x[0]) - summary_df['target metric name'] = validation_metric - summary_df['target metric value'] = validation_target + summary_df['val target metric name'] = validation_metric + summary_df['val target metric value'] = validation_target - summary_df['target reached'] = workload_df[validation_metric].apply( + summary_df['val target reached'] = workload_df[validation_metric].apply( lambda x: target_op(x, validation_target)).apply(np.any) - summary_df['best metric value'] = workload_df[validation_metric].apply( + summary_df['best metric value on val'] = workload_df[validation_metric].apply( lambda x: best_op(x)) - workload_df['index best eval'] = workload_df[validation_metric].apply( + workload_df['index best eval on val'] = workload_df[validation_metric].apply( lambda x: idx_op(x)) - summary_df['time to best eval (s)'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval']], axis=1) - summary_df['time to target (s)'] = summary_df.apply( - lambda x: x['time to best eval (s)'] if x['target reached'] else np.inf, + summary_df['time to best eval on val (s)'] = workload_df.apply( + lambda x: x['accumulated_submission_time'][x['index best eval on val']], + axis=1) + summary_df['time to target on val (s)'] = summary_df.apply( + lambda x: x['time to best eval on val (s)'] + if x['val target reached'] else np.inf, axis=1) + # test metrics + if include_test_split: + test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test') + + summary_df['test target metric name'] = test_metric + summary_df['test target metric value'] = test_target + + summary_df['test target reached'] = workload_df[test_metric].apply( + lambda x: target_op(x, test_target)).apply(np.any) + summary_df['best metric value on test'] = workload_df[test_metric].apply( + lambda x: best_op(x)) + workload_df['index best eval on test'] = workload_df[test_metric].apply( + lambda x: idx_op(x)) + summary_df['time to best eval on test (s)'] = workload_df.apply( + lambda x: x['accumulated_submission_time'][x['index best eval on test'] + ], + axis=1) + summary_df['time to target on test (s)'] = summary_df.apply( + lambda x: x['time to best eval on test (s)'] + if x['test target reached'] else np.inf, + axis=1) + return summary_df -def print_submission_summary(df): +def print_submission_summary(df, include_test_split=True): dfs = [] for workload, group in df.groupby('workload'): - summary_df = get_summary_df(workload, group) + summary_df = get_summary_df( + workload, group, include_test_split=include_test_split) dfs.append(summary_df) df = pd.concat(dfs) logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql')) + return df def main(_): @@ -93,7 +120,10 @@ def main(_): experiment_path = os.path.join(FLAGS.submission_directory, submission) df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df - print_submission_summary(df) + summary_df = print_submission_summary(df) + with open(os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), + 'w') as fout: + summary_df.to_csv(fout) if not FLAGS.strict: logging.warning( diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 722b197a4..0dd997ab9 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -174,6 +174,11 @@ def get_experiment_df(experiment_dir): study_dirs = os.listdir(experiment_dir) for study_dir in study_dirs: workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir)) + workload_dirs = [ + w for w in workload_dirs + if os.path.isdir(os.path.join(experiment_dir, study_dir, w)) + ] + print(workload_dirs) for workload in workload_dirs: data = { 'workload': workload, @@ -208,7 +213,7 @@ def get_experiment_df(experiment_dir): ## Get workload properties -def get_workload_validation_target(workload): +def get_workload_metrics_and_targets(workload, split='validation'): """Returns workload target metric name and value.""" workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) @@ -225,6 +230,10 @@ def get_workload_validation_target(workload): workload_class_name=workload_metadata['workload_class_name'], workload_init_kwargs=workload_init_kwargs) metric_name = workload_obj.target_metric_name - validation_metric = f'validation/{metric_name}' - validation_target = workload_obj.validation_target_value - return validation_metric, validation_target + if split == 'validation': + metric = f'validation/{metric_name}' + target = workload_obj.validation_target_value + elif split == 'test': + metric = f'test/{metric_name}' + target = workload_obj.test_target_value + return metric, target diff --git a/setup.cfg b/setup.cfg index 0c986451b..321020ad9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -120,6 +120,7 @@ jax_core_deps = # upgrade jax. chex==0.1.7 ml_dtypes==0.2.0 + protobuf==4.25.3 # JAX CPU jax_cpu = diff --git a/submission_runner.py b/submission_runner.py index e9a3f7dba..40eb8cd58 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -207,6 +207,7 @@ def train_once( log_dir: Optional[str] = None, save_checkpoints: Optional[bool] = True ) -> Tuple[spec.Timing, Dict[str, Any]]: + _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) # Workload setup. diff --git a/utils/workload_config.json b/utils/target_setting_workload_config.json similarity index 96% rename from utils/workload_config.json rename to utils/target_setting_workload_config.json index bd67768ac..a8c050422 100644 --- a/utils/workload_config.json +++ b/utils/target_setting_workload_config.json @@ -123,25 +123,25 @@ "max_steps": 48000, "dataset": "librispeech", "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json" + "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json" }, "librispeech_deepspeech_no_resnet": { "max_steps": 48000, "dataset": "librispeech", "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json" + "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json" }, "librispeech_deepspeech_norm_and_spec_aug": { "max_steps": 48000, "dataset": "librispeech", "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json" + "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json" }, "librispeech_deepspeech_tanh": { "max_steps": 48000, "dataset": "librispeech", "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json" + "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json" }, "criteo1tb": { "max_steps": 10666, @@ -176,7 +176,7 @@ "librispeech_conformer_attention_temperature": { "max_steps": 80000, "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", + "submission_path": "reference_algorithms/target_setting_algorithms/jax_adamw.py", "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json" }, "librispeech_conformer_gelu": {