From a8187b80cb6389a7efe629c2f0d82cdc6f540072 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 14:46:52 +0000 Subject: [PATCH 01/71] Add pass/fail thresholds to traindiffs test --- tests/test_traindiffs.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index a1b64a573..d7ef3d3ae 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -10,6 +10,7 @@ from absl import flags from absl.testing import absltest +from numpy import allclose FLAGS = flags.FLAGS @@ -81,6 +82,17 @@ def test_workload(self): print(header) print('=' * len(header)) for i in range(NUM_TRAIN_STEPS): + rtol = 1e-1 if workload == 'librispeech_deepspeech' else 5e-3 + self.assertTrue(allclose(jax_results['eval_results'][i][k], + pyt_results['eval_results'][i][k], + rtol=rtol)) + self.assertTrue(allclose(jax_results['scalars'][i]['grad_norm'], + pyt_results['scalars'][i]['grad_norm'], + rtol=rtol)) + self.assertTrue(allclose(jax_results['scalars'][i]['loss'], + pyt_results['scalars'][i]['loss'], + rtol=rtol)) + row = map(lambda x: str(round(x, 5)), [ jax_results['eval_results'][i][k], From 2373e1599b9052f7d4cc3522d88c5c3a6aa5ceb2 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 14:47:20 +0000 Subject: [PATCH 02/71] Add traindiffs_test option to docker startup script --- docker/scripts/startup.sh | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 30cb6b36b..0d9ad1b1e 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -14,8 +14,8 @@ function usage() { $0 [--dataset dataset] [--framework framework] [--submission_path submission_path] [--tuning_search_space tuning_search_space] [--experiment_name experiment_name] [--workload workload] [--max_global_steps max_global_steps] [--rsync_data rsync_data] - [--internal_contributor true] - + [--internal_contributor true] [--traindiffs_test false] + Options: -d | --dataset: Can be imagenet, criteo1tb, ogbg, fastmri, wmt, librispeech. -f | --framework: Can be jax or pytorch. @@ -34,11 +34,13 @@ function usage() { from internal GCP bucket. -i | --internal_contributor: If true, allow rsync of data and transfer of experiment results with GCP project. + --traindiffs_test: If true, ignore all other options and run the traindiffs test. USAGE exit 1 } # Defaults +TEST="false" INTERNAL_CONTRIBUTOR_MODE="false" HOME_DIR="" RSYNC_DATA="true" @@ -47,7 +49,11 @@ SAVE_CHECKPOINTS="true" # Pass flag while [ "$1" != "" ]; do - case $1 in + case $1 in + --traindiffs_test) + shift + TEST=$1 + ;; -d | --dataset) shift DATASET=$1 @@ -106,8 +112,15 @@ while [ "$1" != "" ]; do ;; esac shift -done - +done + +if [[ ${TEST} == "true" ]]; then + cd algorithmic-efficiency + COMMAND="python3 tests/test_traindiffs.py" + echo $COMMAND + eval $COMMAND + exit +fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ From d1da9c7651bc4edf311b3b28ba7c3f232ad3ef5c Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 16:01:57 +0100 Subject: [PATCH 03/71] Rename PytWorkload to PyTorchWorkload --- tests/modeldiffs/criteo1tb/compare.py | 4 ++-- .../criteo1tb_embed_init/compare.py | 4 ++-- .../modeldiffs/criteo1tb_layernorm/compare.py | 4 ++-- tests/modeldiffs/criteo1tb_resnet/compare.py | 4 ++-- tests/modeldiffs/fastmri/compare.py | 4 ++-- tests/modeldiffs/fastmri_layernorm/compare.py | 4 ++-- .../modeldiffs/fastmri_model_size/compare.py | 4 ++-- tests/modeldiffs/fastmri_tanh/compare.py | 4 ++-- tests/modeldiffs/imagenet_resnet/compare.py | 4 ++-- .../imagenet_resnet/gelu_compare.py | 4 ++-- .../imagenet_resnet/silu_compare.py | 4 ++-- tests/modeldiffs/imagenet_vit/compare.py | 4 ++-- tests/modeldiffs/imagenet_vit/glu_compare.py | 4 ++-- .../imagenet_vit/post_ln_compare.py | 4 ++-- .../librispeech_conformer/compare.py | 4 ++-- .../compare.py | 4 ++-- .../librispeech_conformer_gelu/compare.py | 4 ++-- .../compare.py | 4 ++-- .../librispeech_deepspeech/compare.py | 4 ++-- tests/modeldiffs/wmt/compare.py | 4 ++-- .../modeldiffs/wmt_attention_temp/compare.py | 4 ++-- tests/modeldiffs/wmt_glu_tanh/compare.py | 4 ++-- tests/modeldiffs/wmt_post_ln/compare.py | 4 ++-- tests/reference_algorithm_tests.py | 2 +- tests/test_traindiffs.py | 24 ++++++++++++------- 25 files changed, 62 insertions(+), 56 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 9a95f3656..adbade983 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallWorkload as PytWorkload + Criteo1TbDlrmSmallWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -51,7 +51,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 719484037..0748e2d71 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallEmbedInitWorkload as PytWorkload + Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -50,7 +50,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 3fc2a750a..0a6e5c5ac 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as PytWorkload + Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index b9dbbc80e..288442594 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallResNetWorkload as PytWorkload + Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 6780ff91e..56b74b32d 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRIWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PytWorkload + FastMRIWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 4be086da3..23ccf26d7 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRILayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRILayerNormWorkload as PytWorkload + FastMRILayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index 60d846b6f..b61516c29 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRIModelSizeWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIModelSizeWorkload as PytWorkload + FastMRIModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 47bad372a..0f455387c 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRITanhWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRITanhWorkload as PytWorkload + FastMRITanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 2fc721ab0..fb730f1bf 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload as PytWorkload + ImagenetResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -72,7 +72,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 8c3899076..6c8adbec2 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetGELUWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetGELUWorkload as PytWorkload + ImagenetResNetGELUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -19,7 +19,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index ee74e7bc9..7668cdbd9 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetSiLUWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetSiLUWorkload as PytWorkload + ImagenetResNetSiLUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -19,7 +19,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index bf7d6dfa5..ebf39e4c3 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitWorkload as PytWorkload + ImagenetVitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -75,7 +75,7 @@ def key_transform(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit/glu_compare.py index 444f1230a..2c0aa546d 100644 --- a/tests/modeldiffs/imagenet_vit/glu_compare.py +++ b/tests/modeldiffs/imagenet_vit/glu_compare.py @@ -13,7 +13,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitGluWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PytWorkload + ImagenetVitGluWorkload as PyTorchWorkload sd_transform = None @@ -21,7 +21,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit/post_ln_compare.py index 8bf0bef7e..0883b5676 100644 --- a/tests/modeldiffs/imagenet_vit/post_ln_compare.py +++ b/tests/modeldiffs/imagenet_vit/post_ln_compare.py @@ -13,7 +13,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetViTPostLNWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PytWorkload + ImagenetViTPostLNWorkload as PyTorchWorkload sd_transform = None @@ -21,7 +21,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index d414001dd..cfe6c7381 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload as PytWorkload + LibriSpeechConformerWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index 64612fbf0..8480fca02 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as PytWorkload + LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index 892040b57..caa9b09b9 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerGeluWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerGeluWorkload as PytWorkload + LibriSpeechConformerGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 784fceb60..1a94d3c77 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerLayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerLayerNormWorkload as PytWorkload + LibriSpeechConformerLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 12b79a517..edcc3ba87 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechWorkload as PytWorkload + LibriSpeechDeepSpeechWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -83,7 +83,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 8d0ee8411..41fc5ee17 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkload as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkload as PytWorkload + WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index b50abd3ca..92ce4eb44 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadAttentionTemp as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadAttentionTemp as PytWorkload + WmtWorkloadAttentionTemp as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 1322ad0a0..b8d860479 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadGLUTanH as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadGLUTanH as PytWorkload + WmtWorkloadGLUTanH as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index bfd701736..3f5469d8d 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadPostLN as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadPostLN as PytWorkload + WmtWorkloadPostLN as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 5c43b233b..74c06e180 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -197,7 +197,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): f'tests.modeldiffs.{workload_name}.compare') jax_params, model_state, _ = diff_utils.torch2jax( jax_workload=super(), - pytorch_workload=compare_module.PytWorkload(**self.init_kwargs), + pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs), key_transform=compare_module.key_transform, sd_transform=compare_module.sd_transform) return (FrozenDict(**jax_utils.replicate(jax_params)), diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index d7ef3d3ae..663cf3de4 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -83,15 +83,21 @@ def test_workload(self): print('=' * len(header)) for i in range(NUM_TRAIN_STEPS): rtol = 1e-1 if workload == 'librispeech_deepspeech' else 5e-3 - self.assertTrue(allclose(jax_results['eval_results'][i][k], - pyt_results['eval_results'][i][k], - rtol=rtol)) - self.assertTrue(allclose(jax_results['scalars'][i]['grad_norm'], - pyt_results['scalars'][i]['grad_norm'], - rtol=rtol)) - self.assertTrue(allclose(jax_results['scalars'][i]['loss'], - pyt_results['scalars'][i]['loss'], - rtol=rtol)) + self.assertTrue( + allclose( + jax_results['eval_results'][i][k], + pyt_results['eval_results'][i][k], + rtol=rtol)) + self.assertTrue( + allclose( + jax_results['scalars'][i]['grad_norm'], + pyt_results['scalars'][i]['grad_norm'], + rtol=rtol)) + self.assertTrue( + allclose( + jax_results['scalars'][i]['loss'], + pyt_results['scalars'][i]['loss'], + rtol=rtol)) row = map(lambda x: str(round(x, 5)), [ From 6a5d63a7868f622215cb0a68205beb89fad62bd9 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 13 Jan 2024 16:06:11 +0100 Subject: [PATCH 04/71] Add traindiffs tests to workflows (self-hosted) --- .github/workflows/traindiffs_tests.yml | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/traindiffs_tests.yml diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml new file mode 100644 index 000000000..5bb84c867 --- /dev/null +++ b/.github/workflows/traindiffs_tests.yml @@ -0,0 +1,32 @@ +name: Containerized training differences tests between Jax and PyTorch + +on: + pull_request: + branches: + - 'main' + +jobs: + build_and_push_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker image + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=both + IMAGE_NAME="algoperf_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + traindiffs_tests: + runs-on: self-hosted + needs: build_and_push_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized traindiffs test + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} algoperf_${{ github.head_ref || github.ref_name }} --traindiffs_test true From 1683ba37ea41faa69ad83a90d3cb044daad75004 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Jan 2024 18:41:12 +0000 Subject: [PATCH 05/71] add variant scoring conditions --- scoring/performance_profile.py | 85 ++++++++++++++++++++++++++++------ scoring/score_submission.py | 11 ++++- 2 files changed, 81 insertions(+), 15 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 84788c7ae..9322dfaa7 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -36,16 +36,19 @@ import pandas as pd import algorithmic_efficiency.workloads.workloads as workloads_registry +from algorithmic_efficiency.workloads.workloads import get_base_workload_name from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS +BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' # These global variables have to be set according to the current set of # workloads and rules for the scoring to be correct. # We do not use the workload registry since it contains test and development # workloads as well. -NUM_WORKLOADS = 8 +NUM_BASE_WORKLOADS = 8 +NUM_VARIANT_WORKLOADS = 6 NUM_TRIALS = 5 MIN_EVAL_METRICS = [ @@ -152,16 +155,17 @@ def get_index_that_reaches_target(workload_df, def get_times_for_submission(submission, - submission_tag, + submission_name, time_col='global_step', verbosity=1, - self_tuning_ruleset=False): + self_tuning_ruleset=False, + strict=False): """Get times to target for each workload in a submission. Args: submission: A DataFrame containing one row for each trial in each workload for a given submission. - submission_tag: Globally unique identified for a submission. + submission_name: Globally unique identified for a submission. time_col: A string indicating which column to use for time. verbosity: Debug level of information; choice of (1, 2, 3). @@ -169,16 +173,23 @@ def get_times_for_submission(submission, DataFrame with columns `submission`, `workload`, and time_col. """ workloads = [] - submission_name = submission_tag.split('.')[1] num_workloads = len(submission.groupby('workload')) - if num_workloads != NUM_WORKLOADS: - logging.warning(f'Expecting {NUM_WORKLOADS} workloads ' - f'but found {num_workloads} workloads.') + if num_workloads != NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS: + if strict: + raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + logging.warning( + f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' + f'but found {num_workloads} workloads.') for workload, group in submission.groupby('workload'): num_trials = len(group) if num_trials != NUM_TRIALS and not self_tuning_ruleset: - logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') + if strict: + raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + else: + logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) trial_idx, time_idx = get_index_that_reaches_target( @@ -202,12 +213,30 @@ def get_times_for_submission(submission, print(f' - {key}: {val}') else: print('Submission did not reach target') + df = pd.DataFrame.from_records(workloads) + print(df) df = df.pivot(index='submission', columns='workload', values=time_col) + print(time_col) return df +def variant_criteria_filter(base_workload, variant_workload): + + def filter(x): + try: + if x[variant_workload] == np.inf: + return np.inf + else: + return x[base_workload] + except KeyError as e: + print(x.keys()) + raise e + + return filter + + def compute_performance_profiles(results, time_col='global_step', min_tau=1.0, @@ -215,7 +244,9 @@ def compute_performance_profiles(results, reference_submission_tag=None, num_points=100, scale='linear', - verbosity=0): + verbosity=0, + strict=False, + self_tuning_ruleset=False): """Compute performance profiles for a set of submission by some time column. Args: @@ -247,9 +278,37 @@ def compute_performance_profiles(results, f'\nComputing performance profile with respect to `{time_col}` for ' f'{submission_tag}') dfs.append( - get_times_for_submission(result, submission_tag, time_col, verbosity)) + get_times_for_submission(result, + submission_tag, + time_col, + verbosity, + self_tuning_ruleset, + strict)) df = pd.concat(dfs) + # if strict: + + # Set score to inf if not within 4x of fastest submission + best_scores = df.min(axis=0) + df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf + + # For each held-out workload if variant target was not hit set submission to inf + framework = None + for workload in df.keys(): + # Check if this is a variant + framework = workload.split('_')[-1] + workload_ = workload.split(f'_{framework}')[0] + if workload_ not in BASE_WORKLOADS: + # If variants do not have finite score set base_workload score to inf + base_workload = get_base_workload_name(workload_) + df[base_workload] = df.apply( + variant_criteria_filter(base_workload + f'_{framework}', workload), + axis=1) + + base_workloads = [w + f'_{framework}' for w in BASE_WORKLOADS] + df = df[base_workloads] + print(df) + if verbosity > 0: logging.info('\n`{time_col}` to reach target:') with pd.option_context('display.max_rows', @@ -288,7 +347,7 @@ def compute_performance_profiles(results, np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0) def rho(r, tau): - return (r <= tau).sum(axis=1) / NUM_WORKLOADS + return (r <= tau).sum(axis=1) / NUM_BASE_WORKLOADS perf_df = pd.concat([rho(df, tau) for tau in points], axis=1) diff --git a/scoring/score_submission.py b/scoring/score_submission.py index 0dd84ff55..e0a32777f 100644 --- a/scoring/score_submission.py +++ b/scoring/score_submission.py @@ -22,6 +22,11 @@ flags.DEFINE_boolean('compute_performance_profiles', False, 'Whether or not to compute the performance profiles.') +flags.DEFINE_boolean( + 'strict', + False, + 'Whether to enforce scoring criteria on variant' + 'performance and on 5-trial median performance') FLAGS = flags.FLAGS @@ -57,6 +62,7 @@ def main(_): results = { FLAGS.submission_tag: df, } + print(df) dfs = [] for workload, group in df.groupby('workload'): @@ -64,7 +70,7 @@ def main(_): dfs.append(summary_df) df = pd.concat(dfs) - logging.info(tabulate(df, headers='keys', tablefmt='psql')) + logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql')) if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( @@ -75,7 +81,8 @@ def main(_): reference_submission_tag=None, num_points=100, scale='linear', - verbosity=0) + verbosity=0, + strict=FLAGS.strict) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( From 370687deb8abfcff8e9393755d6f80bf0f5d2d2b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Jan 2024 18:43:28 +0000 Subject: [PATCH 06/71] add flag for self-tuning rulset --- scoring/score_submission.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scoring/score_submission.py b/scoring/score_submission.py index e0a32777f..eafb41ce6 100644 --- a/scoring/score_submission.py +++ b/scoring/score_submission.py @@ -27,6 +27,11 @@ False, 'Whether to enforce scoring criteria on variant' 'performance and on 5-trial median performance') +flags.DEFINE_boolean( + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset' +) FLAGS = flags.FLAGS @@ -82,6 +87,7 @@ def main(_): num_points=100, scale='linear', verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, strict=FLAGS.strict) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) From 2128ce8bf7fd600e351f951f4fec5493414f7202 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 18 Jan 2024 18:51:00 +0000 Subject: [PATCH 07/71] score group of submissions --- scoring/score_submission.py | 103 ------------------------------------ 1 file changed, 103 deletions(-) delete mode 100644 scoring/score_submission.py diff --git a/scoring/score_submission.py b/scoring/score_submission.py deleted file mode 100644 index eafb41ce6..000000000 --- a/scoring/score_submission.py +++ /dev/null @@ -1,103 +0,0 @@ -import operator -import os - -from absl import app -from absl import flags -from absl import logging -import numpy as np -import pandas as pd -import scoring_utils -from tabulate import tabulate - -from scoring import performance_profile - -flags.DEFINE_string( - 'experiment_path', - None, - 'Path to experiment directory containing workload directories.') -flags.DEFINE_string('submission_tag', 'my.submission', 'Submission tag.') -flags.DEFINE_string('output_dir', - 'scoring_results', - 'Path to save performance profile table and plot.') -flags.DEFINE_boolean('compute_performance_profiles', - False, - 'Whether or not to compute the performance profiles.') -flags.DEFINE_boolean( - 'strict', - False, - 'Whether to enforce scoring criteria on variant' - 'performance and on 5-trial median performance') -flags.DEFINE_boolean( - 'self_tuning_ruleset', - False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset' -) -FLAGS = flags.FLAGS - - -def get_summary_df(workload, workload_df): - validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) - is_minimized = performance_profile.check_if_minimized(validation_metric) - target_op = operator.le if is_minimized else operator.ge - best_op = min if is_minimized else max - idx_op = np.argmin if is_minimized else np.argmax - - summary_df = pd.DataFrame() - summary_df['workload'] = workload_df['workload'] - summary_df['trial'] = workload_df['trial'] - summary_df['target metric name'] = validation_metric - summary_df['target metric value'] = validation_target - - summary_df['target reached'] = workload_df[validation_metric].apply( - lambda x: target_op(x, validation_target)).apply(np.any) - summary_df['best target'] = workload_df[validation_metric].apply( - lambda x: best_op(x)) - workload_df['index best eval'] = workload_df[validation_metric].apply( - lambda x: idx_op(x)) - summary_df['submission time'] = workload_df.apply( - lambda x: x['accumulated_submission_time'][x['index best eval']], axis=1) - summary_df['score'] = summary_df.apply( - lambda x: x['submission time'] if x['target reached'] else np.inf, axis=1) - - return summary_df - - -def main(_): - df = scoring_utils.get_experiment_df(FLAGS.experiment_path) - results = { - FLAGS.submission_tag: df, - } - print(df) - - dfs = [] - for workload, group in df.groupby('workload'): - summary_df = get_summary_df(workload, group) - dfs.append(summary_df) - - df = pd.concat(dfs) - logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql')) - - if FLAGS.compute_performance_profiles: - performance_profile_df = performance_profile.compute_performance_profiles( - results, - time_col='score', - min_tau=1.0, - max_tau=None, - reference_submission_tag=None, - num_points=100, - scale='linear', - verbosity=0, - self_tuning_ruleset=FLAGS.self_tuning_ruleset, - strict=FLAGS.strict) - if not os.path.exists(FLAGS.output_dir): - os.mkdir(FLAGS.output_dir) - performance_profile.plot_performance_profiles( - performance_profile_df, 'score', save_dir=FLAGS.output_dir) - perf_df = tabulate( - performance_profile_df.T, headers='keys', tablefmt='psql') - logging.info(f'Performance profile:\n {perf_df}') - - -if __name__ == '__main__': - flags.mark_flag_as_required('experiment_path') - app.run(main) From d43ccf4782637c51643d470d865abf203c29665d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 19 Jan 2024 23:11:25 +0000 Subject: [PATCH 08/71] correct max number of steps --- scoring/run_workloads.py | 141 +++++++++++++++++++++++++++++++++++ scoring/score_submissions.py | 104 ++++++++++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 scoring/run_workloads.py create mode 100644 scoring/score_submissions.py diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py new file mode 100644 index 000000000..1804d157d --- /dev/null +++ b/scoring/run_workloads.py @@ -0,0 +1,141 @@ +""" +Example Usage: +python run_all_workloads.py --framework jax \ +--experiment_basename my_first_experiment \ +--docker_image_url \ +--tag \ +--run_percentage 10 \ +--submission_path \ +--tuning_search_space +""" + +from absl import flags +from absl import app +import os +import docker +import time + + +flags.DEFINE_string('tag', None, 'Optional Docker image tag') +flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') +flags.DEFINE_integer('run_percentage', 100, 'Percentage of max num steps to run for.') +flags.DEFINE_string('experiment_basename', 'my_experiment', 'Name of top sub directory in experiment dir.') +flags.DEFINE_boolean('rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.') +flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') +flags.DEFINE_string('submission_path', + 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', + 'Path to reference submission.') +flags.DEFINE_string('tuning_search_space', + 'prize_qualification_baselines/external_tuning/tuning_search_space.json', + 'Path to tuning search space.') +flags.DEFINE_string('framework', + None, + 'Can be either PyTorch or JAX.') +flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') + + +FLAGS = flags.FLAGS + + +DATASETS = ['imagenet', + 'fastmri', + 'ogbg', + 'wmt', + 'librispeech', + 'criteo1tb'] + +WORKLOADS = { + 'imagenet_resnet': {'max_steps': 186_666, + 'dataset': 'imagenet'}, + 'imagenet_vit': {'max_steps': 186_666, + 'dataset': 'imagenet'}, + 'fastmri': {'max_steps': 36_189, + 'dataset': 'fastmri'}, + 'ogbg': {'max_steps': 80_000, + 'dataset': 'ogbg'}, + 'wmt': {'max_steps': 133_333, + 'dataset': 'wmt'}, + 'librispeech_deepspeech': {'max_steps': 48_000, + 'dataset': 'librispeech'}, + 'criteo1tb': {'max_steps': 10_666, + 'dataset': 'criteo1tb'}, + 'librispeech_conformer': {'max_steps': 80_000, + 'dataset': 'librispeech'}, + } + +def container_running(): + docker_client = docker.from_env() + containers = docker_client.containers.list() + if len(containers) == 0: + return False + else: + return True + +def wait_until_container_not_running(sleep_interval=5*60): + while container_running(): + time.sleep(sleep_interval) + return + +def main(_): + framework = FLAGS.framework + algorithm = FLAGS.algorithm + tag = f':{FLAGS.tag}' if FLAGS.tag is not None else '' + run_fraction = FLAGS.run_percentage/100. + experiment_basename=FLAGS.experiment_basename + rsync_data = 'true' if FLAGS.rsync_data else 'false' + docker_image_url = FLAGS.docker_image_url + submission_path = FLAGS.submisison_path + tuning_search_space = FLAGS.tuning_search_space + + # For each runnable workload check if there are any containers running and if not launch next container command + for workload in WORKLOADS.keys(): + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + print('='*100) + dataset = WORKLOADS[workload]['dataset'] + max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) + experiment_name = f'{experiment_basename}/{algorithm}' + mount_repo_flag = '' + if FLAGS.local: + mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v $HOME/data/:/data/ ' + '-v $HOME/experiment_runs/:/experiment_runs ' + '-v $HOME/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url}{tag} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-t {tuning_search_space} ' + f'-e {experiment_name} ' + f'-m {max_steps} ' + '-c false ' + '-o true ' + f'-r {rsync_data} ' + '-i true ') + if not FLAGS.dry_run: + print('Running docker container command') + print('Container ID: ') + return_code = os.system(command) + else: + return_code = 0 + if return_code == 0: + print(f'SUCCESS: container for {framework} {workload} {algorithm} launched successfully') + print(f'Command: {command}') + print(f'Results will be logged to {experiment_name}') + else: + print(f'Failed: container for {framework} {workload} {algorithm} failed with exit code {return_code}.') + print(f'Command: {command}') + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + + print('='*100) + + +if __name__ == '__main__': + flags.mark_flag_as_required('framework') + flags.mark_flag_as_required() + + app.run(main) \ No newline at end of file diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py new file mode 100644 index 000000000..13a0dc9b2 --- /dev/null +++ b/scoring/score_submissions.py @@ -0,0 +1,104 @@ +import operator +import os + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import pandas as pd +import scoring_utils +from tabulate import tabulate + +from scoring import performance_profile + +flags.DEFINE_string( + 'submission_directory, + None, + 'Path to submission directory containing experiment directories.') +flags.DEFINE_string('output_dir', + 'scoring_results', + 'Path to save performance profile table and plot.') +flags.DEFINE_boolean('compute_performance_profiles', + False, + 'Whether or not to compute the performance profiles.') +flags.DEFINE_boolean( + 'strict', + False, + 'Whether to enforce scoring criteria on variant' + 'performance and on 5-trial median performance') +flags.DEFINE_boolean( + 'self_tuning_ruleset', + False, + 'Whether to score on self-tuning ruleset or externally tuned ruleset' +) +FLAGS = flags.FLAGS + + +def get_summary_df(workload, workload_df): + validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) + is_minimized = performance_profile.check_if_minimized(validation_metric) + target_op = operator.le if is_minimized else operator.ge + best_op = min if is_minimized else max + idx_op = np.argmin if is_minimized else np.argmax + + summary_df = pd.DataFrame() + summary_df['workload'] = workload_df['workload'] + summary_df['trial'] = workload_df['trial'] + summary_df['target metric name'] = validation_metric + summary_df['target metric value'] = validation_target + + summary_df['target reached'] = workload_df[validation_metric].apply( + lambda x: target_op(x, validation_target)).apply(np.any) + summary_df['best target'] = workload_df[validation_metric].apply( + lambda x: best_op(x)) + workload_df['index best eval'] = workload_df[validation_metric].apply( + lambda x: idx_op(x)) + summary_df['submission time'] = workload_df.apply( + lambda x: x['accumulated_submission_time'][x['index best eval']], axis=1) + summary_df['score'] = summary_df.apply( + lambda x: x['submission time'] if x['target reached'] else np.inf, axis=1) + + return summary_df + +def print_submission_summary(df): + dfs = [] + for workload, group in df.groupby('workload'): + summary_df = get_summary_df(workload, group) + dfs.append(summary_df) + + df = pd.concat(dfs) + logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql')) + + +def main(_): + results = {} + + for submission in os.path.listdir(FLAGS.submission_directory): + df = scoring_utils.get_experiment_df(FLAGS.experiment_path) + results[submission] = df + print_submission_summary(df) + + if FLAGS.compute_performance_profiles: + performance_profile_df = performance_profile.compute_performance_profiles( + results, + time_col='score', + min_tau=1.0, + max_tau=None, + reference_submission_tag=None, + num_points=100, + scale='linear', + verbosity=0, + self_tuning_ruleset=FLAGS.self_tuning_ruleset, + strict=FLAGS.strict) + if not os.path.exists(FLAGS.output_dir): + os.mkdir(FLAGS.output_dir) + performance_profile.plot_performance_profiles( + performance_profile_df, 'score', save_dir=FLAGS.output_dir) + perf_df = tabulate( + performance_profile_df.T, headers='keys', tablefmt='psql') + logging.info(f'Performance profile:\n {perf_df}') + + +if __name__ == '__main__': + flags.mark_flag_as_required('experiment_path') + app.run(main) From fb814362e5783441b6cf64dfd090f6626bb5cf0e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 19 Jan 2024 23:34:27 +0000 Subject: [PATCH 09/71] add heldout workloads" --- scoring/run_workloads.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 1804d157d..4df0c50ba 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -29,7 +29,7 @@ 'prize_qualification_baselines/external_tuning/tuning_search_space.json', 'Path to tuning search space.') flags.DEFINE_string('framework', - None, + 'jax', 'Can be either PyTorch or JAX.') flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') @@ -63,6 +63,19 @@ 'dataset': 'librispeech'}, } + +HELDOUT_WORKLOADS = { + 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', + 'librispeech_conformer_gelu'], + 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', + 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' + ], + 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], + 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'] + 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] +} + def container_running(): docker_client = docker.from_env() containers = docker_client.containers.list() @@ -135,7 +148,5 @@ def main(_): if __name__ == '__main__': - flags.mark_flag_as_required('framework') - flags.mark_flag_as_required() app.run(main) \ No newline at end of file From 1ea2282f7d82b39363dfed32aef6af49f40dd130 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 23 Jan 2024 18:49:26 +0000 Subject: [PATCH 10/71] add trial args to docker startup.sh" --- docker/scripts/startup.sh | 18 ++++++++++++++++++ scoring/run_workloads.py | 24 +++++++++--------------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 30cb6b36b..2bd8abf33 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -44,6 +44,9 @@ HOME_DIR="" RSYNC_DATA="true" OVERWRITE="false" SAVE_CHECKPOINTS="true" +NUM_TUNING_TRIALS="1" +HPARAM_START_INDEX="None" +HPARAM_END_INDEX="None" # Pass flag while [ "$1" != "" ]; do @@ -100,6 +103,18 @@ while [ "$1" != "" ]; do shift HOME_DIR=$1 ;; + --num_tuning_trials) + shift + NUM_TUNING_TRIALS=$1 + ;; + --hparam_start_index) + shift + HPARAM_START_INDEX=$1 + ;; + --hparam_end_index) + shift + HPARAM_END_INDEX=$1 + ;; *) usage exit 1 @@ -204,6 +219,9 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --experiment_name=${EXPERIMENT_NAME} \ --overwrite=${OVERWRITE} \ --save_checkpoints=${SAVE_CHECKPOINTS} \ + --num_tuning_trials={NUM_TUNING_TRIALS} \ + --hparam_start_index={HPARAM_START_INDEX} \ + --hparam_end_index={HPARAM_END_INDEX} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 4df0c50ba..dff92aa86 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -1,7 +1,7 @@ """ Example Usage: -python run_all_workloads.py --framework jax \ ---experiment_basename my_first_experiment \ +python run_workloads.py --framework jax \ +--experiment_name my_first_experiment \ --docker_image_url \ --tag \ --run_percentage 10 \ @@ -16,10 +16,9 @@ import time -flags.DEFINE_string('tag', None, 'Optional Docker image tag') flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') flags.DEFINE_integer('run_percentage', 100, 'Percentage of max num steps to run for.') -flags.DEFINE_string('experiment_basename', 'my_experiment', 'Name of top sub directory in experiment dir.') +flags.DEFINE_string('experiment_name', 'my_experiment', 'Name of top sub directory in experiment dir.') flags.DEFINE_boolean('rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.') flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') flags.DEFINE_string('submission_path', @@ -72,7 +71,7 @@ ], 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], - 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'] + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] } @@ -91,13 +90,10 @@ def wait_until_container_not_running(sleep_interval=5*60): def main(_): framework = FLAGS.framework - algorithm = FLAGS.algorithm - tag = f':{FLAGS.tag}' if FLAGS.tag is not None else '' run_fraction = FLAGS.run_percentage/100. - experiment_basename=FLAGS.experiment_basename - rsync_data = 'true' if FLAGS.rsync_data else 'false' + experiment_name=FLAGS.experiment_name docker_image_url = FLAGS.docker_image_url - submission_path = FLAGS.submisison_path + submission_path = FLAGS.submission_path tuning_search_space = FLAGS.tuning_search_space # For each runnable workload check if there are any containers running and if not launch next container command @@ -107,7 +103,6 @@ def main(_): print('='*100) dataset = WORKLOADS[workload]['dataset'] max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) - experiment_name = f'{experiment_basename}/{algorithm}' mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' @@ -116,7 +111,7 @@ def main(_): '-v $HOME/experiment_runs/logs:/logs ' f'{mount_repo_flag}' '--gpus all --ipc=host ' - f'{docker_image_url}{tag} ' + f'{docker_image_url} ' f'-d {dataset} ' f'-f {framework} ' f'-s {submission_path} ' @@ -126,7 +121,6 @@ def main(_): f'-m {max_steps} ' '-c false ' '-o true ' - f'-r {rsync_data} ' '-i true ') if not FLAGS.dry_run: print('Running docker container command') @@ -135,11 +129,11 @@ def main(_): else: return_code = 0 if return_code == 0: - print(f'SUCCESS: container for {framework} {workload} {algorithm} launched successfully') + print(f'SUCCESS: container for {framework} {workload} launched successfully') print(f'Command: {command}') print(f'Results will be logged to {experiment_name}') else: - print(f'Failed: container for {framework} {workload} {algorithm} failed with exit code {return_code}.') + print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') print(f'Command: {command}') wait_until_container_not_running() os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches From 0bcb9691a83ed292543a412d1e6e59b83b35fdd1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 24 Jan 2024 21:53:40 +0000 Subject: [PATCH 11/71] add script for sampling held out workloads --- scoring/generate_held_out_workloads.py | 72 +++++++++++++ scoring/run_workloads.py | 139 +++++++++++++++---------- 2 files changed, 155 insertions(+), 56 deletions(-) create mode 100644 scoring/generate_held_out_workloads.py diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py new file mode 100644 index 000000000..cc5c3df71 --- /dev/null +++ b/scoring/generate_held_out_workloads.py @@ -0,0 +1,72 @@ +from absl import app +from absl import flags +from absl import logging +import struct +import os + +import json +import jax +import jax.numpy as jnp +from algorithmic_efficiency import random_utils as prng + + +flags.DEFINE_integer('seed', None, 'Random seed for scoring.') +flags.DEFINE_string('framework', 'jax', "JAX or") +flags.DEFINE_string('output_filename', 'held_out_workloads.json', 'Path to file to record sampled held_out workloads.') +FLAGS = flags.FLAGS + + +HELD_OUT_WORKLOADS = { + 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', + 'librispeech_conformer_gelu'], + 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', + 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' + ], + 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], + 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], + 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], + 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] +} + + +def save_held_out_workloads(held_out_workloads, filename): + with open(filename, "w") as f: + json.dump(held_out_workloads, f) + + +def read_held_out_workloads(filename): + with open(filename, "r") as f: + held_out_workloads = json.load(f) + return held_out_workloads + + + +def main(_): + rng_seed = FLAGS.seed + output_filename = FLAGS.output_filename + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] + + logging.info('Using RNG seed %d', rng_seed) + rng_key = prng.PRNGKey(rng_seed) + + sampled_held_out_workloads = [] + for k, v in HELD_OUT_WORKLOADS.items(): + rng_key, rng_sub_key = prng.split(rng_key, 2) + p = jnp.array([1/len(v) for w in v]) + sampled_index = jax.random.categorical(rng_sub_key, p) + sampled_held_out_workloads.append(v[sampled_index]) + + logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") + + save_held_out_workloads(sampled_held_out_workloads, output_filename) + + +if __name__ == '__main__': + app.run(main) + + + + +print(h) \ No newline at end of file diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index dff92aa86..0f56ead78 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -11,9 +11,13 @@ from absl import flags from absl import app +from absl import logging import os import docker import time +import struct + +from algorithmic_efficiency import random_utils as prng flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') @@ -31,6 +35,15 @@ 'jax', 'Can be either PyTorch or JAX.') flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') +flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') +flags.DEFINE_string('study_start_index', None, 'Start index for studies.') +flags.DEFINE_string('study_end_index', None, 'End index for studies.') +flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') +flags.DEFINE_integer('hparam_start_index', None, 'Start index for tuning trials.') +flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') +flags.DEFINE_integer('seed', None, 'Random seed for scoring.') +flags.DEFINE_integer('submission_id', 0, 'Submission ID to generate study and hparam seeds.') +flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') FLAGS = flags.FLAGS @@ -63,18 +76,6 @@ } -HELDOUT_WORKLOADS = { - 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', - 'librispeech_conformer_gelu'], - 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', - 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' - ], - 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], - 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], - 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], - 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] -} - def container_running(): docker_client = docker.from_env() containers = docker_client.containers.list() @@ -95,50 +96,76 @@ def main(_): docker_image_url = FLAGS.docker_image_url submission_path = FLAGS.submission_path tuning_search_space = FLAGS.tuning_search_space - - # For each runnable workload check if there are any containers running and if not launch next container command - for workload in WORKLOADS.keys(): - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - print('='*100) - dataset = WORKLOADS[workload]['dataset'] - max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) - mount_repo_flag = '' - if FLAGS.local: - mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v $HOME/data/:/data/ ' - '-v $HOME/experiment_runs/:/experiment_runs ' - '-v $HOME/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-t {tuning_search_space} ' - f'-e {experiment_name} ' - f'-m {max_steps} ' - '-c false ' - '-o true ' - '-i true ') - if not FLAGS.dry_run: - print('Running docker container command') - print('Container ID: ') - return_code = os.system(command) - else: - return_code = 0 - if return_code == 0: - print(f'SUCCESS: container for {framework} {workload} launched successfully') - print(f'Command: {command}') - print(f'Results will be logged to {experiment_name}') - else: - print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') - print(f'Command: {command}') - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - - print('='*100) + num_studies = FLAGS.num_studies + num_tuning_trials = FLAGS.num_tuning_trials + hparam_start_index = FLAGS.hparam_start_index + hparam_end_index = FLAGS.hparam_end_index + study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 + study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 + submission_id = FLAGS.submission_id + rng_seed = FLAGS.seed + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] + + logging.info('Using RNG seed %d', rng_seed) + rng_key = prng.fold_in(prng.PRNGKey(rng_seed), submission_id) + rng_keys = prng.split(rng_key, 5) + + for study_index, rng_key in zip(range(study_start_index, study_end_index), rng_keys): + print('-' * 100) + print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) + print('-' * 100) + _, rng_seed = rng_key + study_dir = os.path.join(experiment_name, f'study_{index}') + + # For each runnable workload check if there are any containers running and if not launch next container command + for workload in WORKLOADS.keys(): + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + print('='*100) + dataset = WORKLOADS[workload]['dataset'] + max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) + mount_repo_flag = '' + if FLAGS.local: + mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v $HOME/data/:/data/ ' + '-v $HOME/experiment_runs/:/experiment_runs ' + '-v $HOME/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-t {tuning_search_space} ' + f'-e {study_dir} ' + f'-m {max_steps} ' + f'--num_tuning_trials {num_tuning_trials} ' + f'--hparam_start_index {hparam_start_index} ' + f'--hparam_end_index {hparam_end_index} ' + f'--rng_seed {rng_seed} ' + '-c false ' + '-o true ' + '-i true ') + if not FLAGS.dry_run: + print('Running docker container command') + print('Container ID: ') + return_code = os.system(command) + else: + return_code = 0 + if return_code == 0: + print(f'SUCCESS: container for {framework} {workload} launched successfully') + print(f'Command: {command}') + print(f'Results will be logged to {experiment_name}') + else: + print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') + print(f'Command: {command}') + wait_until_container_not_running() + os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + + print('='*100) if __name__ == '__main__': From ce5f202e06c2d18200fc73e3a6ab6d6397b7fc88 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 01:05:26 +0000 Subject: [PATCH 12/71] add code for run workloads --- scoring/run_workloads.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 0f56ead78..b34f50ece 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -18,6 +18,7 @@ import struct from algorithmic_efficiency import random_utils as prng +from scoring.generate_held_out_workloads import read_held_out_workloads flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') From f431eefc9405adc6609127de32572d943c96435e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:30:43 +0000 Subject: [PATCH 13/71] add workload sampling --- scoring/generate_held_out_workloads.py | 18 +++---------- scoring/run_workloads.py | 37 +++++++++++++++++++------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index cc5c3df71..aa85b9d55 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -10,9 +10,9 @@ from algorithmic_efficiency import random_utils as prng -flags.DEFINE_integer('seed', None, 'Random seed for scoring.') -flags.DEFINE_string('framework', 'jax', "JAX or") +flags.DEFINE_integer('held_out_workloads_seed', None, 'Random seed for scoring.') flags.DEFINE_string('output_filename', 'held_out_workloads.json', 'Path to file to record sampled held_out workloads.') +flags.DEFINE_string('framework', 'jax', 'JAX or PyTorch') FLAGS = flags.FLAGS @@ -34,15 +34,8 @@ def save_held_out_workloads(held_out_workloads, filename): json.dump(held_out_workloads, f) -def read_held_out_workloads(filename): - with open(filename, "r") as f: - held_out_workloads = json.load(f) - return held_out_workloads - - - def main(_): - rng_seed = FLAGS.seed + rng_seed = FLAGS.held_out_workloads_seed output_filename = FLAGS.output_filename if not rng_seed: @@ -65,8 +58,3 @@ def main(_): if __name__ == '__main__': app.run(main) - - - - -print(h) \ No newline at end of file diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index b34f50ece..cfe545b42 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -16,9 +16,10 @@ import docker import time import struct +import json from algorithmic_efficiency import random_utils as prng -from scoring.generate_held_out_workloads import read_held_out_workloads +from algorithmic_efficiency.workloads.workloads import get_base_workload_name flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') @@ -77,6 +78,13 @@ } + +def read_held_out_workloads(filename): + with open(filename, "r") as f: + held_out_workloads = json.load(f) + return held_out_workloads + + def container_running(): docker_client = docker.from_env() containers = docker_client.containers.list() @@ -110,23 +118,32 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = prng.fold_in(prng.PRNGKey(rng_seed), submission_id) - rng_keys = prng.split(rng_key, 5) + rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) + + workloads = [w for w in WORKLOADS.keys()] + + # Read held-out workloads + if FLAGS.held_out_workloads_config_path: + held_out_workloads = read_held_out_workloads(FLAGS.held_out_workloads_config_path) + workloads = workloads + held_out_workloads - for study_index, rng_key in zip(range(study_start_index, study_end_index), rng_keys): + for study_index in range(study_start_index, study_end_index): print('-' * 100) print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) print('-' * 100) - _, rng_seed = rng_key - study_dir = os.path.join(experiment_name, f'study_{index}') + rng_key, rng_subkey = prng.split(rng_key) + study_dir = os.path.join(experiment_name, f'study_{study_index}') # For each runnable workload check if there are any containers running and if not launch next container command - for workload in WORKLOADS.keys(): + for workload in workloads: + rng_subkey, run_key = prng.split(rng_subkey) + run_seed = run_key[0] # arbitrary + base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches print('='*100) - dataset = WORKLOADS[workload]['dataset'] - max_steps = int(WORKLOADS[workload]['max_steps'] * run_fraction) + dataset = WORKLOADS[base_workload_name]['dataset'] + max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' @@ -146,7 +163,7 @@ def main(_): f'--num_tuning_trials {num_tuning_trials} ' f'--hparam_start_index {hparam_start_index} ' f'--hparam_end_index {hparam_end_index} ' - f'--rng_seed {rng_seed} ' + f'--rng_seed {run_seed} ' '-c false ' '-o true ' '-i true ') From f260497025cf7191cfc5883cbc75be46158e5732 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:38:59 +0000 Subject: [PATCH 14/71] formatting --- scoring/generate_held_out_workloads.py | 66 +++--- scoring/run_workloads.py | 301 +++++++++++++------------ 2 files changed, 194 insertions(+), 173 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index aa85b9d55..794a451c2 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -9,52 +9,64 @@ import jax.numpy as jnp from algorithmic_efficiency import random_utils as prng - -flags.DEFINE_integer('held_out_workloads_seed', None, 'Random seed for scoring.') -flags.DEFINE_string('output_filename', 'held_out_workloads.json', 'Path to file to record sampled held_out workloads.') +flags.DEFINE_integer('held_out_workloads_seed', + None, + 'Random seed for scoring.') +flags.DEFINE_string('output_filename', + 'held_out_workloads.json', + 'Path to file to record sampled held_out workloads.') flags.DEFINE_string('framework', 'jax', 'JAX or PyTorch') FLAGS = flags.FLAGS - HELD_OUT_WORKLOADS = { - 'librispeech': ['librispeech_conformer_attention_temperature', 'librispeech_conformer_layernorm', - 'librispeech_conformer_gelu'], - 'imagenet': ['imagenet_resnet_silu', 'imagenet_resnet_gelu', 'imagenet_resnet_large_bn_init', - 'imagenet_vit_gelu', 'imagenet_vit_post_ln', 'imagenet_vit_map' + 'librispeech': [ + 'librispeech_conformer_attention_temperature', + 'librispeech_conformer_layernorm', + 'librispeech_conformer_gelu' + ], + 'imagenet': [ + 'imagenet_resnet_silu', + 'imagenet_resnet_gelu', + 'imagenet_resnet_large_bn_init', + 'imagenet_vit_gelu', + 'imagenet_vit_post_ln', + 'imagenet_vit_map' ], 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'], 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'], 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'], - 'criteo1tb':['criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'] + 'criteo1tb': [ + 'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet' + ] } def save_held_out_workloads(held_out_workloads, filename): - with open(filename, "w") as f: - json.dump(held_out_workloads, f) + with open(filename, "w") as f: + json.dump(held_out_workloads, f) def main(_): - rng_seed = FLAGS.held_out_workloads_seed - output_filename = FLAGS.output_filename + rng_seed = FLAGS.held_out_workloads_seed + output_filename = FLAGS.output_filename + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] - if not rng_seed: - rng_seed = struct.unpack('I', os.urandom(4))[0] - - logging.info('Using RNG seed %d', rng_seed) - rng_key = prng.PRNGKey(rng_seed) + logging.info('Using RNG seed %d', rng_seed) + rng_key = prng.PRNGKey(rng_seed) - sampled_held_out_workloads = [] - for k, v in HELD_OUT_WORKLOADS.items(): - rng_key, rng_sub_key = prng.split(rng_key, 2) - p = jnp.array([1/len(v) for w in v]) - sampled_index = jax.random.categorical(rng_sub_key, p) - sampled_held_out_workloads.append(v[sampled_index]) + sampled_held_out_workloads = [] + for k, v in HELD_OUT_WORKLOADS.items(): + rng_key, rng_sub_key = prng.split(rng_key, 2) + p = jnp.array([1 / len(v) for w in v]) + sampled_index = jax.random.categorical(rng_sub_key, p) + sampled_held_out_workloads.append(v[sampled_index]) - logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") + logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") - save_held_out_workloads(sampled_held_out_workloads, output_filename) + save_held_out_workloads(sampled_held_out_workloads, output_filename) if __name__ == '__main__': - app.run(main) + app.run(main) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index cfe545b42..4f72ebedb 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -14,178 +14,187 @@ from absl import logging import os import docker -import time +import time import struct import json from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency.workloads.workloads import get_base_workload_name - -flags.DEFINE_string('docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') -flags.DEFINE_integer('run_percentage', 100, 'Percentage of max num steps to run for.') -flags.DEFINE_string('experiment_name', 'my_experiment', 'Name of top sub directory in experiment dir.') -flags.DEFINE_boolean('rsync_data', True, 'Whether or not to transfer the data from GCP w rsync.') +flags.DEFINE_string( + 'docker_image_url', + 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', + 'URL to docker image') +flags.DEFINE_integer('run_percentage', + 100, + 'Percentage of max num steps to run for.') +flags.DEFINE_string('experiment_name', + 'my_experiment', + 'Name of top sub directory in experiment dir.') +flags.DEFINE_boolean('rsync_data', + True, + 'Whether or not to transfer the data from GCP w rsync.') flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') -flags.DEFINE_string('submission_path', - 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', - 'Path to reference submission.') -flags.DEFINE_string('tuning_search_space', - 'prize_qualification_baselines/external_tuning/tuning_search_space.json', - 'Path to tuning search space.') -flags.DEFINE_string('framework', - 'jax', - 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean('dry_run', False, 'Whether or not to actually run the command') +flags.DEFINE_string( + 'submission_path', + 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py', + 'Path to reference submission.') +flags.DEFINE_string( + 'tuning_search_space', + 'prize_qualification_baselines/external_tuning/tuning_search_space.json', + 'Path to tuning search space.') +flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') +flags.DEFINE_boolean('dry_run', + False, + 'Whether or not to actually run the command') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_string('study_start_index', None, 'Start index for studies.') flags.DEFINE_string('study_end_index', None, 'End index for studies.') flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') -flags.DEFINE_integer('hparam_start_index', None, 'Start index for tuning trials.') +flags.DEFINE_integer('hparam_start_index', + None, + 'Start index for tuning trials.') flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') flags.DEFINE_integer('seed', None, 'Random seed for scoring.') -flags.DEFINE_integer('submission_id', 0, 'Submission ID to generate study and hparam seeds.') -flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') - +flags.DEFINE_integer('submission_id', + 0, + 'Submission ID to generate study and hparam seeds.') +flags.DEFINE_string('held_out_workloads_config_path', + None, + 'Path to config containing held-out workloads') FLAGS = flags.FLAGS - -DATASETS = ['imagenet', - 'fastmri', - 'ogbg', - 'wmt', - 'librispeech', - 'criteo1tb'] +DATASETS = ['imagenet', 'fastmri', 'ogbg', 'wmt', 'librispeech', 'criteo1tb'] WORKLOADS = { - 'imagenet_resnet': {'max_steps': 186_666, - 'dataset': 'imagenet'}, - 'imagenet_vit': {'max_steps': 186_666, - 'dataset': 'imagenet'}, - 'fastmri': {'max_steps': 36_189, - 'dataset': 'fastmri'}, - 'ogbg': {'max_steps': 80_000, - 'dataset': 'ogbg'}, - 'wmt': {'max_steps': 133_333, - 'dataset': 'wmt'}, - 'librispeech_deepspeech': {'max_steps': 48_000, - 'dataset': 'librispeech'}, - 'criteo1tb': {'max_steps': 10_666, - 'dataset': 'criteo1tb'}, - 'librispeech_conformer': {'max_steps': 80_000, - 'dataset': 'librispeech'}, - } - + 'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'}, + 'imagenet_vit': {'max_steps': 186_666, 'dataset': 'imagenet'}, + 'fastmri': {'max_steps': 36_189, 'dataset': 'fastmri'}, + 'ogbg': {'max_steps': 80_000, 'dataset': 'ogbg'}, + 'wmt': {'max_steps': 133_333, 'dataset': 'wmt'}, + 'librispeech_deepspeech': {'max_steps': 48_000, 'dataset': 'librispeech'}, + 'criteo1tb': {'max_steps': 10_666, 'dataset': 'criteo1tb'}, + 'librispeech_conformer': {'max_steps': 80_000, 'dataset': 'librispeech'}, +} def read_held_out_workloads(filename): - with open(filename, "r") as f: - held_out_workloads = json.load(f) - return held_out_workloads + with open(filename, "r") as f: + held_out_workloads = json.load(f) + return held_out_workloads def container_running(): - docker_client = docker.from_env() - containers = docker_client.containers.list() - if len(containers) == 0: - return False - else: - return True - -def wait_until_container_not_running(sleep_interval=5*60): - while container_running(): - time.sleep(sleep_interval) - return - + docker_client = docker.from_env() + containers = docker_client.containers.list() + if len(containers) == 0: + return False + else: + return True + + +def wait_until_container_not_running(sleep_interval=5 * 60): + while container_running(): + time.sleep(sleep_interval) + return + + def main(_): - framework = FLAGS.framework - run_fraction = FLAGS.run_percentage/100. - experiment_name=FLAGS.experiment_name - docker_image_url = FLAGS.docker_image_url - submission_path = FLAGS.submission_path - tuning_search_space = FLAGS.tuning_search_space - num_studies = FLAGS.num_studies - num_tuning_trials = FLAGS.num_tuning_trials - hparam_start_index = FLAGS.hparam_start_index - hparam_end_index = FLAGS.hparam_end_index - study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 - study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 - submission_id = FLAGS.submission_id - rng_seed = FLAGS.seed - - if not rng_seed: - rng_seed = struct.unpack('I', os.urandom(4))[0] - - logging.info('Using RNG seed %d', rng_seed) - rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) - - workloads = [w for w in WORKLOADS.keys()] - - # Read held-out workloads - if FLAGS.held_out_workloads_config_path: - held_out_workloads = read_held_out_workloads(FLAGS.held_out_workloads_config_path) - workloads = workloads + held_out_workloads - - for study_index in range(study_start_index, study_end_index): - print('-' * 100) - print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) - print('-' * 100) - rng_key, rng_subkey = prng.split(rng_key) - study_dir = os.path.join(experiment_name, f'study_{study_index}') - - # For each runnable workload check if there are any containers running and if not launch next container command - for workload in workloads: - rng_subkey, run_key = prng.split(rng_subkey) - run_seed = run_key[0] # arbitrary - base_workload_name = get_base_workload_name(workload) - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - print('='*100) - dataset = WORKLOADS[base_workload_name]['dataset'] - max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) - mount_repo_flag = '' - if FLAGS.local: - mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v $HOME/data/:/data/ ' - '-v $HOME/experiment_runs/:/experiment_runs ' - '-v $HOME/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-t {tuning_search_space} ' - f'-e {study_dir} ' - f'-m {max_steps} ' - f'--num_tuning_trials {num_tuning_trials} ' - f'--hparam_start_index {hparam_start_index} ' - f'--hparam_end_index {hparam_end_index} ' - f'--rng_seed {run_seed} ' - '-c false ' - '-o true ' - '-i true ') - if not FLAGS.dry_run: - print('Running docker container command') - print('Container ID: ') - return_code = os.system(command) - else: - return_code = 0 - if return_code == 0: - print(f'SUCCESS: container for {framework} {workload} launched successfully') - print(f'Command: {command}') - print(f'Results will be logged to {experiment_name}') - else: - print(f'Failed: container for {framework} {workload} failed with exit code {return_code}.') - print(f'Command: {command}') - wait_until_container_not_running() - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - - print('='*100) + framework = FLAGS.framework + run_fraction = FLAGS.run_percentage / 100. + experiment_name = FLAGS.experiment_name + docker_image_url = FLAGS.docker_image_url + submission_path = FLAGS.submission_path + tuning_search_space = FLAGS.tuning_search_space + num_studies = FLAGS.num_studies + num_tuning_trials = FLAGS.num_tuning_trials + hparam_start_index = FLAGS.hparam_start_index + hparam_end_index = FLAGS.hparam_end_index + study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 + study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 + submission_id = FLAGS.submission_id + rng_seed = FLAGS.seed + + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] + + logging.info('Using RNG seed %d', rng_seed) + rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) + + workloads = [w for w in WORKLOADS.keys()] + + # Read held-out workloads + if FLAGS.held_out_workloads_config_path: + held_out_workloads = read_held_out_workloads( + FLAGS.held_out_workloads_config_path) + workloads = workloads + held_out_workloads + + for study_index in range(study_start_index, study_end_index): + print('-' * 100) + print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) + print('-' * 100) + rng_key, rng_subkey = prng.split(rng_key) + study_dir = os.path.join(experiment_name, f'study_{study_index}') + + # For each runnable workload check if there are any containers running and if not launch next container command + for workload in workloads: + rng_subkey, run_key = prng.split(rng_subkey) + run_seed = run_key[0] # arbitrary + base_workload_name = get_base_workload_name(workload) + wait_until_container_not_running() + os.system( + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + print('=' * 100) + dataset = WORKLOADS[base_workload_name]['dataset'] + max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) + mount_repo_flag = '' + if FLAGS.local: + mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v $HOME/data/:/data/ ' + '-v $HOME/experiment_runs/:/experiment_runs ' + '-v $HOME/experiment_runs/logs:/logs ' + f'{mount_repo_flag}' + '--gpus all --ipc=host ' + f'{docker_image_url} ' + f'-d {dataset} ' + f'-f {framework} ' + f'-s {submission_path} ' + f'-w {workload} ' + f'-t {tuning_search_space} ' + f'-e {study_dir} ' + f'-m {max_steps} ' + f'--num_tuning_trials {num_tuning_trials} ' + f'--hparam_start_index {hparam_start_index} ' + f'--hparam_end_index {hparam_end_index} ' + f'--rng_seed {run_seed} ' + '-c false ' + '-o true ' + '-i true ') + if not FLAGS.dry_run: + print('Running docker container command') + print('Container ID: ') + return_code = os.system(command) + else: + return_code = 0 + if return_code == 0: + print( + f'SUCCESS: container for {framework} {workload} launched successfully' + ) + print(f'Command: {command}') + print(f'Results will be logged to {experiment_name}') + else: + print( + f'Failed: container for {framework} {workload} failed with exit code {return_code}.' + ) + print(f'Command: {command}') + wait_until_container_not_running() + os.system( + "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches + + print('=' * 100) if __name__ == '__main__': - app.run(main) \ No newline at end of file + app.run(main) From 1a41f8b82957012c2a13ae9d76d61be91348e061 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:39:28 +0000 Subject: [PATCH 15/71] imports --- scoring/generate_held_out_workloads.py | 9 +++++---- scoring/performance_profile.py | 2 +- scoring/run_workloads.py | 13 +++++++------ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index 794a451c2..c61e637bd 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -1,12 +1,13 @@ +import json +import os +import struct + from absl import app from absl import flags from absl import logging -import struct -import os - -import json import jax import jax.numpy as jnp + from algorithmic_efficiency import random_utils as prng flags.DEFINE_integer('held_out_workloads_seed', diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 9322dfaa7..ef4e97f88 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -35,8 +35,8 @@ import numpy as np import pandas as pd -import algorithmic_efficiency.workloads.workloads as workloads_registry from algorithmic_efficiency.workloads.workloads import get_base_workload_name +import algorithmic_efficiency.workloads.workloads as workloads_registry from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 4f72ebedb..7ccd0ca9b 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -9,17 +9,18 @@ --tuning_search_space """ -from absl import flags -from absl import app -from absl import logging +import json import os -import docker -import time import struct -import json +import time + +from absl import app +from absl import flags +from absl import logging from algorithmic_efficiency import random_utils as prng from algorithmic_efficiency.workloads.workloads import get_base_workload_name +import docker flags.DEFINE_string( 'docker_image_url', From 87df162a3ff5e2ae6b04b74fcbd8226016caeea8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:47:29 +0000 Subject: [PATCH 16/71] make seed splitting parallelizable --- scoring/run_workloads.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 7ccd0ca9b..f285d66d4 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -131,11 +131,12 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - for study_index in range(study_start_index, study_end_index): + rng_subkeys = prng.split(rng_key, num_studies)[study_start_index:study_end_index:] + + for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) print('-' * 100) - rng_key, rng_subkey = prng.split(rng_key) study_dir = os.path.join(experiment_name, f'study_{study_index}') # For each runnable workload check if there are any containers running and if not launch next container command From 9d9cdb9dab98efb0d0f90a7cc1e813a89bd8d95a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:49:46 +0000 Subject: [PATCH 17/71] fix --- scoring/score_submissions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 13a0dc9b2..1f7a3a1e7 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -12,7 +12,7 @@ from scoring import performance_profile flags.DEFINE_string( - 'submission_directory, + 'submission_directory', None, 'Path to submission directory containing experiment directories.') flags.DEFINE_string('output_dir', From 17753071c72770862b9db08ed9493546689c2126 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 05:50:49 +0000 Subject: [PATCH 18/71] formatting --- scoring/run_workloads.py | 3 ++- scoring/score_submissions.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index f285d66d4..47a47ca58 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -131,7 +131,8 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - rng_subkeys = prng.split(rng_key, num_studies)[study_start_index:study_end_index:] + rng_subkeys = prng.split(rng_key, + num_studies)[study_start_index:study_end_index:] for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 1f7a3a1e7..106c6b1da 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -29,8 +29,7 @@ flags.DEFINE_boolean( 'self_tuning_ruleset', False, - 'Whether to score on self-tuning ruleset or externally tuned ruleset' -) + 'Whether to score on self-tuning ruleset or externally tuned ruleset') FLAGS = flags.FLAGS @@ -60,6 +59,7 @@ def get_summary_df(workload, workload_df): return summary_df + def print_submission_summary(df): dfs = [] for workload, group in df.groupby('workload'): From 2a1170858bf003077c41a42e08daad1571ef37d3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:41:56 +0000 Subject: [PATCH 19/71] held out workloads example --- scoring/held_out_workloads_example.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 scoring/held_out_workloads_example.json diff --git a/scoring/held_out_workloads_example.json b/scoring/held_out_workloads_example.json new file mode 100644 index 000000000..2b3d6d6b2 --- /dev/null +++ b/scoring/held_out_workloads_example.json @@ -0,0 +1 @@ +["librispeech_conformer_gelu", "imagenet_resnet_silu", "ogbg_gelu", "wmt_post_ln", "fastmri_model_size", "criteo1tb_layernorm"] \ No newline at end of file From a8385a21a3c402dc643bd27e31260343ef81e3b5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:44:31 +0000 Subject: [PATCH 20/71] add docker for run_workloads.py --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 20139d4c0..4fa84951f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ setup_requires = # Dependencies of the project: install_requires = absl-py==1.4.0 + docker==7.0.0 numpy>=1.23 pandas>=2.0.1 tensorflow==2.12.0 From ffddbdc1cb7020f16567ee4b4778da353ca2bdb6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:52:23 +0000 Subject: [PATCH 21/71] fix run_workloads.py --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 47a47ca58..6bf09469c 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -132,7 +132,7 @@ def main(_): workloads = workloads + held_out_workloads rng_subkeys = prng.split(rng_key, - num_studies)[study_start_index:study_end_index:] + num_studies)[:num_studies:] for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) From 91cdf34351d0f9213918495a14f362fb57e5ad7d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 22:53:19 +0000 Subject: [PATCH 22/71] fix --- scoring/run_workloads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 6bf09469c..19291ead0 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -49,8 +49,8 @@ False, 'Whether or not to actually run the command') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') -flags.DEFINE_string('study_start_index', None, 'Start index for studies.') -flags.DEFINE_string('study_end_index', None, 'End index for studies.') +flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') +flags.DEFINE_integer('study_end_index', None, 'End index for studies.') flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.') flags.DEFINE_integer('hparam_start_index', None, From 95572ad3c1af539b0e93601c9108740aca25e22e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 23:38:27 +0000 Subject: [PATCH 23/71] add rng seed to startup.sh docker script --- docker/scripts/startup.sh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 2bd8abf33..b06375c34 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -34,6 +34,10 @@ function usage() { from internal GCP bucket. -i | --internal_contributor: If true, allow rsync of data and transfer of experiment results with GCP project. + --num_tuning_trials Number of tuning trials for externally tuned ruleset submission. + --hparam_start_index Should be > 0 and < num_tuning_trials - 1. + --hparam_end_index Should be > 0 and < num_tuning_trials - 1. + --rng_seed RNG seed to pass to workload submission_runner. USAGE exit 1 } @@ -47,6 +51,7 @@ SAVE_CHECKPOINTS="true" NUM_TUNING_TRIALS="1" HPARAM_START_INDEX="None" HPARAM_END_INDEX="None" +RNG_SEED="None" # Pass flag while [ "$1" != "" ]; do @@ -115,6 +120,10 @@ while [ "$1" != "" ]; do shift HPARAM_END_INDEX=$1 ;; + --rng_seed) + shift + RNG_SEED=$1 + ;; *) usage exit 1 @@ -222,6 +231,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --num_tuning_trials={NUM_TUNING_TRIALS} \ --hparam_start_index={HPARAM_START_INDEX} \ --hparam_end_index={HPARAM_END_INDEX} \ + --rng_seed={RNG_SEED} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" From d577d5c5758897937a590288f586cf01307f1267 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 23:51:13 +0000 Subject: [PATCH 24/71] fix --- docker/scripts/startup.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index b06375c34..b5ad18941 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -228,9 +228,9 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --experiment_name=${EXPERIMENT_NAME} \ --overwrite=${OVERWRITE} \ --save_checkpoints=${SAVE_CHECKPOINTS} \ - --num_tuning_trials={NUM_TUNING_TRIALS} \ - --hparam_start_index={HPARAM_START_INDEX} \ - --hparam_end_index={HPARAM_END_INDEX} \ + --num_tuning_trials=${NUM_TUNING_TRIALS} \ + --hparam_start_index=${HPARAM_START_INDEX} \ + --hparam_end_index=${HPARAM_END_INDEX} \ --rng_seed={RNG_SEED} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ From 91ff705f67e2b7ce08d12d150e323b6cf0fa6e7b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 25 Jan 2024 23:53:06 +0000 Subject: [PATCH 25/71] fix --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index b5ad18941..c0328ffb4 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -231,7 +231,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --num_tuning_trials=${NUM_TUNING_TRIALS} \ --hparam_start_index=${HPARAM_START_INDEX} \ --hparam_end_index=${HPARAM_END_INDEX} \ - --rng_seed={RNG_SEED} \ + --rng_seed=${RNG_SEED} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" From 296dc1ecc3ecb6035ce3b9faefbe84ebf89f6fdc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:07:52 +0000 Subject: [PATCH 26/71] fix --- docker/scripts/startup.sh | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index c0328ffb4..b4eff52ff 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -48,10 +48,6 @@ HOME_DIR="" RSYNC_DATA="true" OVERWRITE="false" SAVE_CHECKPOINTS="true" -NUM_TUNING_TRIALS="1" -HPARAM_START_INDEX="None" -HPARAM_END_INDEX="None" -RNG_SEED="None" # Pass flag while [ "$1" != "" ]; do @@ -204,6 +200,22 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then MAX_STEPS_FLAG="--max_global_steps=${MAX_GLOBAL_STEPS}" fi + if [[ ! -z ${NUM_TUNING_TRIALS+x} ]]; then + NUM_TUNING_TRIALS_FLAG="--num_tuning_trials=${NUM_TUNING_TRIALS}" + fi + + if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then + HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}" + fi + + if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then + HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}" + fi + + if [[ ! -z ${RNG_SEED+x} ]]; then + RNG_SEED_FLAG="--rng_seed=${RNG_SEED}" + fi + # Define special flags for imagenet and librispeech workloads if [[ ${DATASET} == "imagenet" ]]; then SPECIAL_FLAGS="--imagenet_v2_data_dir=${DATA_DIR}" @@ -228,10 +240,10 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then --experiment_name=${EXPERIMENT_NAME} \ --overwrite=${OVERWRITE} \ --save_checkpoints=${SAVE_CHECKPOINTS} \ - --num_tuning_trials=${NUM_TUNING_TRIALS} \ - --hparam_start_index=${HPARAM_START_INDEX} \ - --hparam_end_index=${HPARAM_END_INDEX} \ - --rng_seed=${RNG_SEED} \ + ${NUM_TUNING_TRIALS_FLAG} \ + ${HPARAM_START_INDEX_FLAG} \ + ${HPARAM_END_INDEX_FLAG} \ + ${RNG_SEED_FLAG} \ ${MAX_STEPS_FLAG} \ ${SPECIAL_FLAGS} \ ${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}" From a5b1154343e96b822afd9abcc47ae005728f12f6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:30:14 +0000 Subject: [PATCH 27/71] fix --- .../external_tuning/tuning_search_space.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index 65562905a..b5aff94a2 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -10,6 +10,7 @@ }, { "dropout_rate": 0.0, + "label_smoothing": 0.1, "label_smoothing": 0.2, "learning_rate": 0.0008445074561975979, "one_minus_beta1": 0.11042418465, @@ -19,6 +20,7 @@ }, { "dropout_rate": 0.0, + "label_smoothing": 0.1, "learning_rate": 0.001308209823469072, "one_minus_beta1": 0.02686663061, "beta2": 0.9981232922116359, @@ -27,6 +29,7 @@ }, { "dropout_rate": 0.0, + "label_smoothing": 0.1, "learning_rate": 0.004958460849689891, "one_minus_beta1": 0.13625575743, "beta2": 0.6291854735396584, From 226544dc816c719334f08096f394a80c8f168a4e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:30:25 +0000 Subject: [PATCH 28/71] fix --- .../external_tuning/tuning_search_space.json | 1 + 1 file changed, 1 insertion(+) diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index b5aff94a2..910b9a70a 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -38,6 +38,7 @@ }, { "dropout_rate": 0.1, + "label_smoothing": 0.1, "learning_rate": 0.0017486387539278373, "one_minus_beta1": 0.06733926164, "beta2": 0.9955159689799007, From 6faad0431001a9a645dd4ff45e11febdbff5eb92 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:31:54 +0000 Subject: [PATCH 29/71] fix --- .../external_tuning/tuning_search_space.json | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json index 910b9a70a..199f77041 100644 --- a/prize_qualification_baselines/external_tuning/tuning_search_space.json +++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -10,7 +10,6 @@ }, { "dropout_rate": 0.0, - "label_smoothing": 0.1, "label_smoothing": 0.2, "learning_rate": 0.0008445074561975979, "one_minus_beta1": 0.11042418465, @@ -20,7 +19,7 @@ }, { "dropout_rate": 0.0, - "label_smoothing": 0.1, + "label_smoothing": 0.0, "learning_rate": 0.001308209823469072, "one_minus_beta1": 0.02686663061, "beta2": 0.9981232922116359, @@ -29,7 +28,7 @@ }, { "dropout_rate": 0.0, - "label_smoothing": 0.1, + "label_smoothing": 0.0, "learning_rate": 0.004958460849689891, "one_minus_beta1": 0.13625575743, "beta2": 0.6291854735396584, @@ -38,7 +37,7 @@ }, { "dropout_rate": 0.1, - "label_smoothing": 0.1, + "label_smoothing": 0.0, "learning_rate": 0.0017486387539278373, "one_minus_beta1": 0.06733926164, "beta2": 0.9955159689799007, From 9e7def9ef7d1ab0e502a3225dc532ab73df162e0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 00:56:34 +0000 Subject: [PATCH 30/71] fix log message --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 19291ead0..f04e8f8df 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -136,7 +136,7 @@ def main(_): for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) - print('*' * 40, f'Starting study {study_index}/{num_studies}', '*' * 40) + print('*' * 40, f'Starting study {study_index}/{num_studies - 1}', '*' * 40) print('-' * 100) study_dir = os.path.join(experiment_name, f'study_{study_index}') From 9b410b71cc75b1ba11017fa679f9582cc71b3a5f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:25:30 +0000 Subject: [PATCH 31/71] fix --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index f04e8f8df..72b43dd9f 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -136,7 +136,7 @@ def main(_): for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) - print('*' * 40, f'Starting study {study_index}/{num_studies - 1}', '*' * 40) + print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40) print('-' * 100) study_dir = os.path.join(experiment_name, f'study_{study_index}') From 7634a0bc6c63798f3c93a5e7bd143237a97c5925 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:29:22 +0000 Subject: [PATCH 32/71] debug --- docker/scripts/startup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index b4eff52ff..2f1ebb4b7 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -205,7 +205,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then fi if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then - HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}" + HPARAM_START_INDEX_FLAG="--hparam_start_index=blabla" fi if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then From 235bc69de749670012fd48901ca1c72daf617b11 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:43:46 +0000 Subject: [PATCH 33/71] debugging --- docker/scripts/startup.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 2f1ebb4b7..5e7c74988 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -205,10 +205,11 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then fi if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then - HPARAM_START_INDEX_FLAG="--hparam_start_index=blabla" + HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}" fi if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then + echo "SETTING FLAGGGGGG" HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}" fi From a8d04ccd1ada1bcd97be84b23858440430b1aca8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:45:24 +0000 Subject: [PATCH 34/71] debugging --- docker/scripts/startup.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 5e7c74988..914e7d640 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -210,6 +210,7 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then echo "SETTING FLAGGGGGG" + echo ${HPARAM_END_INDEX} HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}" fi From b2571b230c9808bad1089ceb41df53fac4c106d5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:50:30 +0000 Subject: [PATCH 35/71] fix --- scoring/run_workloads.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 72b43dd9f..3dea262d4 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -110,8 +110,10 @@ def main(_): tuning_search_space = FLAGS.tuning_search_space num_studies = FLAGS.num_studies num_tuning_trials = FLAGS.num_tuning_trials - hparam_start_index = FLAGS.hparam_start_index - hparam_end_index = FLAGS.hparam_end_index + if FLAGS.hparam_start_index: + hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} ' + if FLAGS.hparam_end_index: + hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} ' study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 submission_id = FLAGS.submission_id @@ -168,8 +170,8 @@ def main(_): f'-e {study_dir} ' f'-m {max_steps} ' f'--num_tuning_trials {num_tuning_trials} ' - f'--hparam_start_index {hparam_start_index} ' - f'--hparam_end_index {hparam_end_index} ' + f'{hparam_start_index_flag} ' + f'{hparam_end_index_flag} ' f'--rng_seed {run_seed} ' '-c false ' '-o true ' From 18bc3474c158b63b32991306bc06553520d1e84d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:50:55 +0000 Subject: [PATCH 36/71] remove debugging statemetns --- docker/scripts/startup.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 914e7d640..b4eff52ff 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -209,8 +209,6 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then fi if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then - echo "SETTING FLAGGGGGG" - echo ${HPARAM_END_INDEX} HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}" fi From 4a986985d3c535882c64c4163868f008e0d280e2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 18:59:48 +0000 Subject: [PATCH 37/71] fix --- scoring/run_workloads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 3dea262d4..82afbfb7a 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -110,6 +110,8 @@ def main(_): tuning_search_space = FLAGS.tuning_search_space num_studies = FLAGS.num_studies num_tuning_trials = FLAGS.num_tuning_trials + hparam_start_index_flag = '' + hparam_end_index_flag = '' if FLAGS.hparam_start_index: hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} ' if FLAGS.hparam_end_index: From 4d38e55e1e2d658765732b1e276137b0472745e1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 26 Jan 2024 19:08:14 +0000 Subject: [PATCH 38/71] formatting --- scoring/run_workloads.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 82afbfb7a..083dafb6a 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -135,8 +135,7 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - rng_subkeys = prng.split(rng_key, - num_studies)[:num_studies:] + rng_subkeys = prng.split(rng_key, num_studies)[:num_studies:] for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) From 4d413f47e155fc8b14d90f02248df90ac2955a0d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 01:28:51 +0000 Subject: [PATCH 39/71] take into account median of studies for scoring --- scoring/performance_profile.py | 47 ++++++++++++-------- scoring/score_submissions.py | 9 ++-- scoring/scoring_utils.py | 78 ++++++++++++++++++---------------- 3 files changed, 78 insertions(+), 56 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index ef4e97f88..9c334ee22 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -50,6 +50,7 @@ NUM_BASE_WORKLOADS = 8 NUM_VARIANT_WORKLOADS = 6 NUM_TRIALS = 5 +NUM_STUDIES = 5 MIN_EVAL_METRICS = [ 'ce_loss', @@ -151,6 +152,7 @@ def get_index_that_reaches_target(workload_df, else: index_reached = target_reached.apply(np.argmax) trial = index_reached.idxmin() + print(trial, index_reached[trial]) return trial, index_reached[trial] @@ -182,27 +184,40 @@ def get_times_for_submission(submission, f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads ' f'but found {num_workloads} workloads.') for workload, group in submission.groupby('workload'): - num_trials = len(group) - if num_trials != NUM_TRIALS and not self_tuning_ruleset: - if strict: - raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') - else: - logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload) - trial_idx, time_idx = get_index_that_reaches_target( - group, validation_metric, validation_target) - if time_idx > -1: - time_val = group[time_col].loc[trial_idx][time_idx] - else: - time_val = float('inf') + time_vals_per_study = [] + num_studies = len(group.groupby('study')) + if num_studies != NUM_STUDIES: + if strict: + raise ValueError(f'Expecting {NUM_STUDIES} trials for workload ' + f'{workload} but found {num_studies} trials.') + else: + logging.warning(f'Expecting {NUM_STUDIES} trials for workload ' + f'{workload} but found {num_studies} trials.') + for study, group in group.groupby('study'): + num_trials = len(group) + if num_trials != NUM_TRIALS and not self_tuning_ruleset: + if strict: + raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + else: + logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' + f'{workload} but found {num_trials} trials.') + + trial_idx, time_idx = get_index_that_reaches_target( + group, validation_metric, validation_target) + if time_idx > -1: + time_val = group[time_col].loc[trial_idx][time_idx] + else: + time_val = float('inf') + time_vals_per_study.append(time_val) + workloads.append({ 'submission': submission_name, 'workload': workload, - time_col: time_val, + time_col: np.median(time_val), }) if verbosity > 0: @@ -215,9 +230,7 @@ def get_times_for_submission(submission, print('Submission did not reach target') df = pd.DataFrame.from_records(workloads) - print(df) df = df.pivot(index='submission', columns='workload', values=time_col) - print(time_col) return df diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 106c6b1da..67e0317ae 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -73,9 +73,12 @@ def print_submission_summary(df): def main(_): results = {} - for submission in os.path.listdir(FLAGS.submission_directory): - df = scoring_utils.get_experiment_df(FLAGS.experiment_path) + for submission in os.listdir(FLAGS.submission_directory): + experiment_path = os.path.join(FLAGS.submission_directory, submission) + df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df + print('SUMMARY ') + print(df.keys()) print_submission_summary(df) if FLAGS.compute_performance_profiles: @@ -100,5 +103,5 @@ def main(_): if __name__ == '__main__': - flags.mark_flag_as_required('experiment_path') + # flags.mark_flag_as_required('submission_directory') app.run(main) diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 8252c75a9..1fff9e6b5 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -150,52 +150,58 @@ def get_experiment_df(experiment_dir): collected together. The directory structure is assumed to be: + experiment_dir - + - + - - eval_measurements.csv + + study + + + + + - eval_measurements.csv Returns: df: DataFrame where indices are trials, columns are - metric names and values are lists. + metric names and values are lists of length num evals. e.g - +----+-----------+-----------------------------+--------------------+--------------------+ - | | workload | trial | validation/accuracy| score | - |----+-----------+-----------------------------+--------------------+--------------------| - | 0 | mnist_jax | (trial_1, ) | [0.0911, 0.0949] | [10.6396, 10.6464] | - +----+-----------+-----------------------------+--------------------+--------------------+ + +----+-----------+--------+----------------------------+--------------------+--------------------+ + | | workload | study |trial | validation/accuracy| score | + |----+-----------+--------+----------------------------+--------------------+--------------------| + | 0 | mnist_jax | 0 |(trial_1, ) | [0.0911, 0.0949] | [10.6396, 10.6464] | + +----+-----------+--------+----------------------------+--------------------+--------------------+ """ df = pd.DataFrame() paths = filter( lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir, glob.glob(f"{experiment_dir}*")) for experiment_dir in list(paths): - workload_dirs = os.listdir(experiment_dir) - for workload in workload_dirs: - data = { - 'workload': workload, - } - trial_dirs = [ - t for t in os.listdir(os.path.join(experiment_dir, workload)) - if re.match(TRIAL_DIR_REGEX, t) - ] - for trial in trial_dirs: - eval_measurements_filepath = os.path.join( - experiment_dir, - workload, - trial, - MEASUREMENTS_FILENAME, - ) - try: - trial_df = pd.read_csv(eval_measurements_filepath) - except FileNotFoundError as e: - logging.info(f'Could not read {eval_measurements_filepath}') - continue - data['trial'] = (trial, experiment_dir) - for column in trial_df.columns: - values = trial_df[column].to_numpy() - data[column] = values - trial_df = pd.DataFrame([data]) - df = pd.concat([df, trial_df], ignore_index=True) + study_dirs = os.listdir(experiment_dir) + for study_dir in study_dirs: + workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir)) + for workload in workload_dirs: + data = { + 'workload': workload, + } + logging.info(os.path.join(experiment_dir, study_dir, workload)) + trial_dirs = [ + t for t in os.listdir(os.path.join(experiment_dir, study_dir, workload)) + if re.match(TRIAL_DIR_REGEX, t) + ] + for trial in trial_dirs: + eval_measurements_filepath = os.path.join( + experiment_dir, + study_dir, + workload, + trial, + MEASUREMENTS_FILENAME, + ) + try: + trial_df = pd.read_csv(eval_measurements_filepath) + except FileNotFoundError as e: + logging.info(f'Could not read {eval_measurements_filepath}') + continue + data['trial'] = (trial, experiment_dir) + data['study'] = study_dir + for column in trial_df.columns: + values = trial_df[column].to_numpy() + data[column] = values + trial_df = pd.DataFrame([data]) + df = pd.concat([df, trial_df], ignore_index=True) return df From 84c87b94261a227c1784f82ce009fe4e1733e9a2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 01:29:37 +0000 Subject: [PATCH 40/71] remove debugging --- scoring/performance_profile.py | 1 - scoring/score_submissions.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 9c334ee22..ba0002d5d 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -152,7 +152,6 @@ def get_index_that_reaches_target(workload_df, else: index_reached = target_reached.apply(np.argmax) trial = index_reached.idxmin() - print(trial, index_reached[trial]) return trial, index_reached[trial] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 67e0317ae..866030c44 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -77,8 +77,6 @@ def main(_): experiment_path = os.path.join(FLAGS.submission_directory, submission) df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df - print('SUMMARY ') - print(df.keys()) print_submission_summary(df) if FLAGS.compute_performance_profiles: From d6e2a36db2391f00889f77bb2a25c10bac0998bf Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 01:49:50 +0000 Subject: [PATCH 41/71] formatting --- scoring/performance_profile.py | 7 +++---- scoring/scoring_utils.py | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index ba0002d5d..6dc3f00d8 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -190,16 +190,16 @@ def get_times_for_submission(submission, if num_studies != NUM_STUDIES: if strict: raise ValueError(f'Expecting {NUM_STUDIES} trials for workload ' - f'{workload} but found {num_studies} trials.') + f'{workload} but found {num_studies} trials.') else: logging.warning(f'Expecting {NUM_STUDIES} trials for workload ' - f'{workload} but found {num_studies} trials.') + f'{workload} but found {num_studies} trials.') for study, group in group.groupby('study'): num_trials = len(group) if num_trials != NUM_TRIALS and not self_tuning_ruleset: if strict: raise ValueError(f'Expecting {NUM_TRIALS} trials for workload ' - f'{workload} but found {num_trials} trials.') + f'{workload} but found {num_trials} trials.') else: logging.warning(f'Expecting {NUM_TRIALS} trials for workload ' f'{workload} but found {num_trials} trials.') @@ -211,7 +211,6 @@ def get_times_for_submission(submission, else: time_val = float('inf') time_vals_per_study.append(time_val) - workloads.append({ 'submission': submission_name, diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 1fff9e6b5..b17b9c5bc 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -179,7 +179,8 @@ def get_experiment_df(experiment_dir): } logging.info(os.path.join(experiment_dir, study_dir, workload)) trial_dirs = [ - t for t in os.listdir(os.path.join(experiment_dir, study_dir, workload)) + t for t in os.listdir( + os.path.join(experiment_dir, study_dir, workload)) if re.match(TRIAL_DIR_REGEX, t) ] for trial in trial_dirs: From f34838a4ac4b1c8f39da35f24c06dc243e89055e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 27 Jan 2024 02:08:28 +0000 Subject: [PATCH 42/71] documentation --- GETTING_STARTED.md | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 96a7b7d6f..eea06ba67 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -336,11 +336,44 @@ docker exec -it /bin/bash ``` ## Score your Submission +To score your submission we will score over all workloads, held-out workloads and studies as described in the rules. +In other words, the total number of runs expected for official scoring is: +- for external ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) x 5 +- for internal ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) -To produce performance profile and performance table: +You may have the time or compute resources to run all required runs, so our scoring scripts will allow some flexibility. + +### Running workloads +To run workloads for scoring you may specify a "virtual" list of held-out workloads. It is important +to note that the official set of held-out workloads will be sampled by the competition organizers during scoring time. + +An example config for held-out workloads is stored in `scoring/held_workloads_example.json`. +To generate a new sample of held out workloads run: + +```bash +python3 generate_held_out_workloads.py --seed --output_filename +``` + +To run a number of studies and trials over all workload using Docker containers for each run: + +```bash +python scoring/run_workloads.py \ +--framework \ +--experiment_name \ +--docker_image_url \ +--submission_path \ +--tuning_search_space \ +--held_out_workloads_config_path held_out_workloads_example.json \ +--num_studies +--seed +``` + +Note that to run the above script you will need the minimum jax_cpu and pytorch_cpu installations of the algorithmic-efficiency package. + +Finally to get the raw scores and performance profiles of group of submissions or single submission: ```bash -python3 scoring/score_submission.py --experiment_path= --output_dir= +python score_submissions.py --submission_directory --output_dir --compute_performance_profiles ``` We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179). From 84dbb075ca510221fa1f04a7ce0d79cdbff82ae0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 30 Jan 2024 20:29:47 +0000 Subject: [PATCH 43/71] fix --- scoring/performance_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 6dc3f00d8..6b49253d8 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -215,7 +215,7 @@ def get_times_for_submission(submission, workloads.append({ 'submission': submission_name, 'workload': workload, - time_col: np.median(time_val), + time_col: np.median(time_vals_per_study), }) if verbosity > 0: From 6d3b0aec04d8003cd9c37f8ce806edb4bb5e8c44 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 05:01:21 +0000 Subject: [PATCH 44/71] remove indexing for rng_subkeys --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 083dafb6a..af319e67b 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -135,7 +135,7 @@ def main(_): FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - rng_subkeys = prng.split(rng_key, num_studies)[:num_studies:] + rng_subkeys = prng.split(rng_key, num_studies) for study_index, rng_subkey in zip(range(study_start_index, study_end_index), rng_subkeys): print('-' * 100) From 7b23443349eb84d3e7c184b29f224a04666d2491 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 05:37:11 +0000 Subject: [PATCH 45/71] add documentation --- scoring/run_workloads.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index af319e67b..ec8b1f8ab 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -56,7 +56,7 @@ None, 'Start index for tuning trials.') flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') -flags.DEFINE_integer('seed', None, 'Random seed for scoring.') +flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.') flags.DEFINE_integer('submission_id', 0, 'Submission ID to generate study and hparam seeds.') @@ -125,7 +125,7 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), submission_id)) + rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))) workloads = [w for w in WORKLOADS.keys()] @@ -145,7 +145,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - rng_subkey, run_key = prng.split(rng_subkey) + run_key = prng.fold_in(rng_subkey, hash(workload)) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() From 4b77dddaa0d221da85532e02820bd9c947f3aa7c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:03:04 +0000 Subject: [PATCH 46/71] fix documentation --- GETTING_STARTED.md | 6 ++++-- scoring/run_workloads.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index eea06ba67..48c23a1a6 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -337,9 +337,11 @@ docker exec -it /bin/bash ## Score your Submission To score your submission we will score over all workloads, held-out workloads and studies as described in the rules. +We will sample 1 held-out workload per dataset for a total of 6 held-out workloads and will use the sampled +held-out workloads in the scoring criteria for the matching base workloads. In other words, the total number of runs expected for official scoring is: -- for external ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) x 5 -- for internal ruleset 8 (workloads) + 6 (held-out workloads) x 5 (studies) +- for external ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) x 5 (trials) +- for internal ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) You may have the time or compute resources to run all required runs, so our scoring scripts will allow some flexibility. diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index ec8b1f8ab..53d1aa2ee 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -63,10 +63,10 @@ flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') +flags.DEFINE_string('') FLAGS = flags.FLAGS -DATASETS = ['imagenet', 'fastmri', 'ogbg', 'wmt', 'librispeech', 'criteo1tb'] WORKLOADS = { 'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'}, From 2f480096a306db7aa9699c0762b268c41070813f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:13:39 +0000 Subject: [PATCH 47/71] add warning --- scoring/score_submissions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 866030c44..156d1b2f9 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -24,8 +24,9 @@ flags.DEFINE_boolean( 'strict', False, - 'Whether to enforce scoring criteria on variant' - 'performance and on 5-trial median performance') + 'Whether to enforce scoring criteria on variant performance and on' + '5-trial median performance. Note that during official scoring this ' + 'flag will be set to True.') flags.DEFINE_boolean( 'self_tuning_ruleset', False, @@ -78,6 +79,12 @@ def main(_): df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df print_submission_summary(df) + + if not FLAGS.strict: + logging.warning('You are running with strict=False. This will relax ' + 'scoring criteria on the held-out workloads, number of trials and number ' + 'of studies. Your score may not be an accurate representation ' + 'under competition scoring rules. To enforce the criteria set strict=True.') if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( From d39eb245cc54f13e42c676a99ef22c3fdd6c1346 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:16:26 +0000 Subject: [PATCH 48/71] typo --- scoring/performance_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 6b49253d8..d0351390b 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -166,7 +166,7 @@ def get_times_for_submission(submission, Args: submission: A DataFrame containing one row for each trial in each workload for a given submission. - submission_name: Globally unique identified for a submission. + submission_name: Globally unique identifier for a submission. time_col: A string indicating which column to use for time. verbosity: Debug level of information; choice of (1, 2, 3). From aecb37f93bf6c7df05f003acb525eb05d87d6071 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:18:54 +0000 Subject: [PATCH 49/71] fix documentation --- GETTING_STARTED.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 48c23a1a6..3fbb29ba5 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -343,7 +343,7 @@ In other words, the total number of runs expected for official scoring is: - for external ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) x 5 (trials) - for internal ruleset (8 (workloads) + 6 (held-out workloads)) x 5 (studies) -You may have the time or compute resources to run all required runs, so our scoring scripts will allow some flexibility. + ### Running workloads To run workloads for scoring you may specify a "virtual" list of held-out workloads. It is important @@ -372,7 +372,10 @@ python scoring/run_workloads.py \ Note that to run the above script you will need the minimum jax_cpu and pytorch_cpu installations of the algorithmic-efficiency package. -Finally to get the raw scores and performance profiles of group of submissions or single submission: +During submission development, it might be useful to do faster, approximate scoring (e.g. without 5 different s +tudies or when some trials are missing) so the scoring scripts allow some flexibility. To simulate official scoring, +pass the `--strict=True` flag in score_submission.py. To get the raw scores and performance profiles of group of +submissions or single submission: ```bash python score_submissions.py --submission_directory --output_dir --compute_performance_profiles From 5135cc85f73f72e5d416dd35d209be258f2aec59 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 31 Jan 2024 23:47:22 +0000 Subject: [PATCH 50/71] remove prng import from generate_held_out_workloads.py --- scoring/generate_held_out_workloads.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index c61e637bd..db449cb1f 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -1,14 +1,11 @@ import json import os +import numpy as np import struct from absl import app from absl import flags from absl import logging -import jax -import jax.numpy as jnp - -from algorithmic_efficiency import random_utils as prng flags.DEFINE_integer('held_out_workloads_seed', None, @@ -55,17 +52,14 @@ def main(_): rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) - rng_key = prng.PRNGKey(rng_seed) + rng = np.random.default_rng(rng_seed) sampled_held_out_workloads = [] for k, v in HELD_OUT_WORKLOADS.items(): - rng_key, rng_sub_key = prng.split(rng_key, 2) - p = jnp.array([1 / len(v) for w in v]) - sampled_index = jax.random.categorical(rng_sub_key, p) + sampled_index = rng.integers(len(v)) sampled_held_out_workloads.append(v[sampled_index]) logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}") - save_held_out_workloads(sampled_held_out_workloads, output_filename) From 6d4f82e8c81b2a9296b9aa99347ff4e400dbf62e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 1 Feb 2024 00:27:46 +0000 Subject: [PATCH 51/71] fix technical documentation --- DOCUMENTATION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index a25f5b689..62b9cba0f 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -409,7 +409,7 @@ The held-out workloads function similarly to a holdout test set discouraging sub Modifications could, for example, include changing the number of layers or units (drawn from an interval), swapping the activation function (drawn from a set of applicable functions), or using different data augmentations (drawn from a list of possible pre-processing steps). The sample space should be wide enough to discourage submitters from simply trying them all out, but at the same time should be restricted enough to produce realistic workloads with acceptable achievable performances. -In the first iteration of this benchmark, we manually designed three different workloads variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each fixed workload. +In the first iteration of this benchmark, we manually designed three different workloads variants for each fixed workload. The variants are designed such that they achieve a comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each dataset. Our scoring procedure uses the held-out workloads only to penalize submissions that can't handle the introduced modifications (see the [Scoring](#scoring) section for further details). From aaa1014cd156bbb94d4423281ebcd09944d4296c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 1 Feb 2024 01:04:46 +0000 Subject: [PATCH 52/71] formatting --- scoring/run_workloads.py | 42 ++++++++++++++++++++---------------- scoring/score_submissions.py | 12 ++++++----- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 53d1aa2ee..bfd02c476 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -45,9 +45,11 @@ 'prize_qualification_baselines/external_tuning/tuning_search_space.json', 'Path to tuning search space.') flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean('dry_run', - False, - 'Whether or not to actually run the command') +flags.DEFINE_boolean( + 'dry_run', + False, + 'Whether or not to actually run the docker containers. ' + 'If False, simply print the docker run commands. ') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') flags.DEFINE_integer('study_end_index', None, 'End index for studies.') @@ -63,23 +65,20 @@ flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') -flags.DEFINE_string('') +flags.DEFINE_string( + 'workload_meta_data_config_path', + None, + 'Path to config containing dataset and maximum number of steps per workload.' + 'The default values of these are set to the full budgets as determined ' + 'via the target-setting procedure. ' + 'Note that training will be interrupted at either the set maximum number ' + 'of steps or the fixed workload maximum run time, whichever comes first. ' + 'If your algorithm has a smaller per step time than our baselines ' + 'you may want to increase the number of steps per workload.') FLAGS = flags.FLAGS -WORKLOADS = { - 'imagenet_resnet': {'max_steps': 186_666, 'dataset': 'imagenet'}, - 'imagenet_vit': {'max_steps': 186_666, 'dataset': 'imagenet'}, - 'fastmri': {'max_steps': 36_189, 'dataset': 'fastmri'}, - 'ogbg': {'max_steps': 80_000, 'dataset': 'ogbg'}, - 'wmt': {'max_steps': 133_333, 'dataset': 'wmt'}, - 'librispeech_deepspeech': {'max_steps': 48_000, 'dataset': 'librispeech'}, - 'criteo1tb': {'max_steps': 10_666, 'dataset': 'criteo1tb'}, - 'librispeech_conformer': {'max_steps': 80_000, 'dataset': 'librispeech'}, -} - - def read_held_out_workloads(filename): with open(filename, "r") as f: held_out_workloads = json.load(f) @@ -127,7 +126,10 @@ def main(_): logging.info('Using RNG seed %d', rng_seed) rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))) - workloads = [w for w in WORKLOADS.keys()] + with open(FLAGS.workload_meta_data_config_path) as f: + workload_meta_data = json.load(f) + + workloads = [w for w in workload_meta_data.keys()] # Read held-out workloads if FLAGS.held_out_workloads_config_path: @@ -152,8 +154,9 @@ def main(_): os.system( "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches print('=' * 100) - dataset = WORKLOADS[base_workload_name]['dataset'] - max_steps = int(WORKLOADS[base_workload_name]['max_steps'] * run_fraction) + dataset = workload_meta_data[base_workload_name]['dataset'] + max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * + run_fraction) mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' @@ -202,5 +205,6 @@ def main(_): if __name__ == '__main__': + flags.mark_flag_as_required('workload_meta_data_config_path') app.run(main) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 156d1b2f9..aafc5530a 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -79,12 +79,14 @@ def main(_): df = scoring_utils.get_experiment_df(experiment_path) results[submission] = df print_submission_summary(df) - + if not FLAGS.strict: - logging.warning('You are running with strict=False. This will relax ' - 'scoring criteria on the held-out workloads, number of trials and number ' - 'of studies. Your score may not be an accurate representation ' - 'under competition scoring rules. To enforce the criteria set strict=True.') + logging.warning( + 'You are running with strict=False. This will relax ' + 'scoring criteria on the held-out workloads, number of trials and number ' + 'of studies. Your score may not be an accurate representation ' + 'under competition scoring rules. To enforce the criteria set strict=True.' + ) if FLAGS.compute_performance_profiles: performance_profile_df = performance_profile.compute_performance_profiles( From 761a877743ae06df055a7d4d7c8744eb28bb23d0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 1 Feb 2024 01:06:18 +0000 Subject: [PATCH 53/71] add default for workload metadata config file --- scoring/run_workloads.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index bfd02c476..e9f76566f 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -45,11 +45,10 @@ 'prize_qualification_baselines/external_tuning/tuning_search_space.json', 'Path to tuning search space.') flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean( - 'dry_run', - False, - 'Whether or not to actually run the docker containers. ' - 'If False, simply print the docker run commands. ') +flags.DEFINE_boolean('dry_run', + False, + 'Whether or not to actually run the docker containers. ' + 'If False, simply print the docker run commands. ') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') flags.DEFINE_integer('study_end_index', None, 'End index for studies.') @@ -65,16 +64,15 @@ flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') -flags.DEFINE_string( - 'workload_meta_data_config_path', - None, - 'Path to config containing dataset and maximum number of steps per workload.' - 'The default values of these are set to the full budgets as determined ' - 'via the target-setting procedure. ' - 'Note that training will be interrupted at either the set maximum number ' - 'of steps or the fixed workload maximum run time, whichever comes first. ' - 'If your algorithm has a smaller per step time than our baselines ' - 'you may want to increase the number of steps per workload.') +flags.DEFINE_string('workload_meta_data_config_path', + 'workload_meta_data.json', + 'Path to config containing dataset and maximum number of steps per workload.' + 'The default values of these are set to the full budgets as determined ' + 'via the target-setting procedure. ' + 'Note that training will be interrupted at either the set maximum number ' + 'of steps or the fixed workload maximum run time, whichever comes first. ' + 'If your algorithm has a smaller per step time than our baselines ' + 'you may want to increase the number of steps per workload.') FLAGS = flags.FLAGS @@ -155,8 +153,7 @@ def main(_): "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches print('=' * 100) dataset = workload_meta_data[base_workload_name]['dataset'] - max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * - run_fraction) + max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * run_fraction) mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' @@ -205,6 +202,4 @@ def main(_): if __name__ == '__main__': - flags.mark_flag_as_required('workload_meta_data_config_path') - app.run(main) From 6b3827a4d63030cce79998688d6885170c4a9dab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 1 Feb 2024 01:13:24 +0000 Subject: [PATCH 54/71] yapf fix --- scoring/run_workloads.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index e9f76566f..077ce8d4f 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -45,10 +45,11 @@ 'prize_qualification_baselines/external_tuning/tuning_search_space.json', 'Path to tuning search space.') flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean('dry_run', - False, - 'Whether or not to actually run the docker containers. ' - 'If False, simply print the docker run commands. ') +flags.DEFINE_boolean( + 'dry_run', + False, + 'Whether or not to actually run the docker containers. ' + 'If False, simply print the docker run commands. ') flags.DEFINE_integer('num_studies', 5, 'Number of studies to run') flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') flags.DEFINE_integer('study_end_index', None, 'End index for studies.') @@ -64,15 +65,16 @@ flags.DEFINE_string('held_out_workloads_config_path', None, 'Path to config containing held-out workloads') -flags.DEFINE_string('workload_meta_data_config_path', - 'workload_meta_data.json', - 'Path to config containing dataset and maximum number of steps per workload.' - 'The default values of these are set to the full budgets as determined ' - 'via the target-setting procedure. ' - 'Note that training will be interrupted at either the set maximum number ' - 'of steps or the fixed workload maximum run time, whichever comes first. ' - 'If your algorithm has a smaller per step time than our baselines ' - 'you may want to increase the number of steps per workload.') +flags.DEFINE_string( + 'workload_meta_data_config_path', + 'workload_meta_data.json', + 'Path to config containing dataset and maximum number of steps per workload.' + 'The default values of these are set to the full budgets as determined ' + 'via the target-setting procedure. ' + 'Note that training will be interrupted at either the set maximum number ' + 'of steps or the fixed workload maximum run time, whichever comes first. ' + 'If your algorithm has a smaller per step time than our baselines ' + 'you may want to increase the number of steps per workload.') FLAGS = flags.FLAGS @@ -153,7 +155,8 @@ def main(_): "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches print('=' * 100) dataset = workload_meta_data[base_workload_name]['dataset'] - max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * run_fraction) + max_steps = int(workload_meta_data[base_workload_name]['max_steps'] * + run_fraction) mount_repo_flag = '' if FLAGS.local: mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' From c0e1aad8aca0a7bc91d0673957847ef20f2158bf Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 1 Feb 2024 01:15:50 +0000 Subject: [PATCH 55/71] import order --- scoring/generate_held_out_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/generate_held_out_workloads.py b/scoring/generate_held_out_workloads.py index db449cb1f..474c4e7d7 100644 --- a/scoring/generate_held_out_workloads.py +++ b/scoring/generate_held_out_workloads.py @@ -1,11 +1,11 @@ import json import os -import numpy as np import struct from absl import app from absl import flags from absl import logging +import numpy as np flags.DEFINE_integer('held_out_workloads_seed', None, From c79433927a4f4d6024aba43cf00f78f2101a1ef5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 01:46:51 +0000 Subject: [PATCH 56/71] fix fold_in in pytorch --- algorithmic_efficiency/random_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 68e9a9cfe..db97735cb 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -33,10 +33,12 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) -def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: - rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - return [new_seed, data] +def _fold_in(seed: SeedType, data: int) -> SeedType: + rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) + new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data)) + new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + return new_seed_1 + new_seed_2 def _split(seed: SeedType, num: int = 2) -> SeedType: @@ -58,7 +60,7 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') -def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: +def fold_in(seed: SeedType, data: int) -> SeedType: if FLAGS.framework == 'jax': _check_jax_install() return jax_rng.fold_in(seed, data) From 247dcb0e2ccabe73b36bcfd72ab55473381958a7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 22:56:34 +0000 Subject: [PATCH 57/71] random utils fixes --- algorithmic_efficiency/random_utils.py | 16 ++++++++++++++-- scoring/run_workloads.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index db97735cb..b4a26b5b0 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -36,14 +36,14 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: int) -> SeedType: rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data)) + rng_2 = np.random.RandomState(seed=(_signed_to_unsigned(data) & 0xffffffff)) new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) return new_seed_1 + new_seed_2 def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name @@ -60,6 +60,11 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') +def _randint(seed: SeedType) -> int: + rng = np.random.RandomState(_signed_to_unsigned(seed)) + return rng.randint(MAX_INT32) + + def fold_in(seed: SeedType, data: int) -> SeedType: if FLAGS.framework == 'jax': _check_jax_install() @@ -79,3 +84,10 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name _check_jax_install() return jax_rng.PRNGKey(seed) return _PRNGKey(seed) + + +def randint(seed:SeedType) -> int: + if FLAGS.framework == 'jax': + _check_jax_install() + return jax_rng.randint(seed, ) + return _randint(seed) \ No newline at end of file diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 077ce8d4f..a464da341 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -148,7 +148,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: run_key = prng.fold_in(rng_subkey, hash(workload)) - run_seed = run_key[0] # arbitrary + run_seed = prng.randint(run_key) # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system( From 759c90df58ca91057fe750b310e4e746f5bf800c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 23:06:17 +0000 Subject: [PATCH 58/71] remove indexing from rngs in pytorch workloads --- .../workloads/cifar/cifar_pytorch/workload.py | 4 ++-- .../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 2 +- .../workloads/fastmri/fastmri_pytorch/workload.py | 2 +- .../workloads/imagenet_resnet/imagenet_pytorch/workload.py | 4 ++-- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 2 +- .../librispeech_conformer/librispeech_pytorch/workload.py | 2 +- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 2 +- .../workloads/mnist/mnist_pytorch/workload.py | 2 +- .../workloads/ogbg/ogbg_pytorch/workload.py | 2 +- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py index af86c212e..e2d655e9b 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py @@ -65,7 +65,7 @@ def _build_dataset( } if split == 'eval_train': train_indices = indices_split['train'] - random.Random(data_rng[0]).shuffle(train_indices) + random.Random(data_rng).shuffle(train_indices) indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) @@ -111,7 +111,7 @@ def init_model_fn( self._model.reset_parameters() return self._model, None - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) self._model = resnet18(num_classes=self._num_classes) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 85bb602d1..bff5fa837 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -72,7 +72,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """Only dropout is used.""" del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) # Disable cudnn benchmark to avoid OOM errors. torch.backends.cudnn.benchmark = False if self.use_resnet: diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..0ad1b3eeb 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -113,7 +113,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = UNet( num_pool_layers=self.num_pool_layers, num_channels=self.num_channels, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 6727054c9..ba2012644 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -103,7 +103,7 @@ def _build_dataset( if split == 'eval_train': indices = list(range(self.num_train_examples)) - random.Random(data_rng[0]).shuffle(indices) + random.Random(data_rng).shuffle(indices) dataset = torch.utils.data.Subset(dataset, indices[:self.num_eval_train_examples]) @@ -147,7 +147,7 @@ def init_model_fn( """Dropout is unused.""" del dropout_rate del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) if self.use_silu and self.use_gelu: raise RuntimeError('Cannot use both GELU and SiLU activations.') diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index e672e8d22..aec3f1aaf 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -30,7 +30,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 20f27b150..9f0a6f841 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -58,7 +58,7 @@ def init_model_fn( Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as input_dropout_rate. """ - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False torch.backends.cuda.enable_flash_sdp(False) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index bcdd78fb5..c968b528d 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -32,7 +32,7 @@ def init_model_fn( Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate as input_dropout_rate. """ - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = DeepspeechEncoderDecoder( DeepspeechConfig( feed_forward_dropout_rate=dropout_rate, diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index e638df078..a60e6040e 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -133,7 +133,7 @@ def init_model_fn( self._model.reset_parameters() return self._model, None - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) self._model = _Model() self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..84a445c4b 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -143,7 +143,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is unused.""" del aux_dropout_rate - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) model = GNN( num_outputs=self._num_outputs, dropout_rate=dropout_rate, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 9f6d817f4..9ee959a4f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -171,7 +171,7 @@ def init_model_fn( dropout_rate: Optional[float] = None, aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - torch.random.manual_seed(rng[0]) + torch.random.manual_seed(rng) if self.activation == 'relu': activation = F.relu From 6baacd7489de82704823a30716e4f1b24f51dc65 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 23:07:30 +0000 Subject: [PATCH 59/71] formatting --- algorithmic_efficiency/random_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index b4a26b5b0..ce3f6b84d 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -86,8 +86,8 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name return _PRNGKey(seed) -def randint(seed:SeedType) -> int: +def randint(seed: SeedType) -> int: if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.randint(seed, ) - return _randint(seed) \ No newline at end of file + return jax_rng.randint(seed,) + return _randint(seed) From 368e348c1cb5aae149642979eab43d4abe8c4f9e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 23:25:24 +0000 Subject: [PATCH 60/71] formatting --- algorithmic_efficiency/workloads/mnist/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..2b8202995 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -55,7 +55,7 @@ def _build_mnist_dataset( if shuffle: ds = ds.repeat() - ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0]) + ds = ds.shuffle(16 * global_batch_size, seed=data_rng) ds = ds.batch(global_batch_size, drop_remainder=is_train) if repeat_final_dataset: From 4be495f5e8712caac8a19c05d6e1e7cb77ed3b69 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 2 Feb 2024 23:47:10 +0000 Subject: [PATCH 61/71] fix seed shapes --- tests/test_param_shapes.py | 2 +- tests/test_param_types.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..afa752cb5 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -113,7 +113,7 @@ def get_workload(workload): else: raise ValueError(f'Workload {workload} is not available.') _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) - _ = pytorch_workload.init_model_fn([0]) + _ = pytorch_workload.init_model_fn(0) return jax_workload, pytorch_workload diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 7cf8f63c3..639c7372d 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -221,7 +221,7 @@ def get_workload(workload_name): else: raise ValueError(f'Workload {workload_name} is not available.') _ = jax_workload.init_model_fn(jax.random.PRNGKey(0)) - _ = pytorch_workload.init_model_fn([0]) + _ = pytorch_workload.init_model_fn(0) return jax_workload, pytorch_workload From 5495d72cb84916cc60e05d492360ca53188fa6e7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 00:05:38 +0000 Subject: [PATCH 62/71] fix dataset seed --- algorithmic_efficiency/workloads/mnist/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 2b8202995..834790bb0 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -55,7 +55,7 @@ def _build_mnist_dataset( if shuffle: ds = ds.repeat() - ds = ds.shuffle(16 * global_batch_size, seed=data_rng) + ds = ds.shuffle(16 * global_batch_size, seed=prng.randint(data_rng)) ds = ds.batch(global_batch_size, drop_remainder=is_train) if repeat_final_dataset: From 355ebe05902235331699891215684057a1c96e85 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 00:20:37 +0000 Subject: [PATCH 63/71] fix rng utils --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index ce3f6b84d..4f2b922f2 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -89,5 +89,5 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name def randint(seed: SeedType) -> int: if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.randint(seed,) + return jax_rng.randint(seed, (0), 0, MAX_INT32) return _randint(seed) From 43c381aca98076d7d70195c40aa4c17326f321aa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 00:48:45 +0000 Subject: [PATCH 64/71] fix --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 4f2b922f2..caa660c2a 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -89,5 +89,5 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name def randint(seed: SeedType) -> int: if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.randint(seed, (0), 0, MAX_INT32) + return jax_rng.randint(seed, (), 0, MAX_INT32) return _randint(seed) From f81e877ee91e76479470eb96f5e1778fe43bbe3e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 01:07:56 +0000 Subject: [PATCH 65/71] remove unused types from random_utils --- algorithmic_efficiency/random_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index caa660c2a..fa60ddef2 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -1,6 +1,6 @@ """Proxy functions in front of the Jax RNG API or a compatible Numpy RNG API.""" -from typing import Any, List, Union +from typing import Union from absl import flags from absl import logging @@ -36,7 +36,7 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: int) -> SeedType: rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - rng_2 = np.random.RandomState(seed=(_signed_to_unsigned(data) & 0xffffffff)) + rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data) & 0xffffffff) new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) return new_seed_1 + new_seed_2 From 3494f16c79784c27de1974ece9d0ff706496f09f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 01:31:28 +0000 Subject: [PATCH 66/71] fix overflow error in jax sampling --- algorithmic_efficiency/random_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index fa60ddef2..0e83a705f 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -86,8 +86,8 @@ def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name return _PRNGKey(seed) -def randint(seed: SeedType) -> int: +def bits(seed: SeedType) -> int: if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.randint(seed, (), 0, MAX_INT32) + return jax_rng.bits(seed) return _randint(seed) From 665049ab2cca90bbb84c799d7f1313ec585c9790 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 01:32:48 +0000 Subject: [PATCH 67/71] use bits instead of randint --- scoring/run_workloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index a464da341..83d7a5f65 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -148,7 +148,7 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: run_key = prng.fold_in(rng_subkey, hash(workload)) - run_seed = prng.randint(run_key) # arbitrary + run_seed = prng.bits(run_key) base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() os.system( From d0d3e3e997ef657661e89f5a0cd601ee567879af Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 01:46:41 +0000 Subject: [PATCH 68/71] add numpy bits function --- algorithmic_efficiency/random_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 0e83a705f..f2b797453 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -60,9 +60,10 @@ def _check_jax_install() -> None: '--framework=pytorch to use the Numpy version instead.') -def _randint(seed: SeedType) -> int: +def _bits(seed: SeedType) -> int: rng = np.random.RandomState(_signed_to_unsigned(seed)) - return rng.randint(MAX_INT32) + b = rng.bytes(4) + return int.from_bytes(b, byteorder='little') def fold_in(seed: SeedType, data: int) -> SeedType: @@ -90,4 +91,4 @@ def bits(seed: SeedType) -> int: if FLAGS.framework == 'jax': _check_jax_install() return jax_rng.bits(seed) - return _randint(seed) + return _bits(seed) From e4476aa9d47c515ce9f18c55d28d891de2024a5a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 3 Feb 2024 02:04:43 +0000 Subject: [PATCH 69/71] change seed method call for mnist dataset --- algorithmic_efficiency/workloads/mnist/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 834790bb0..2b593948c 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -55,7 +55,7 @@ def _build_mnist_dataset( if shuffle: ds = ds.repeat() - ds = ds.shuffle(16 * global_batch_size, seed=prng.randint(data_rng)) + ds = ds.shuffle(16 * global_batch_size, seed=prng.bits(data_rng)) ds = ds.batch(global_batch_size, drop_remainder=is_train) if repeat_final_dataset: From 9ebdca77c299ef15adf99120dd373e7ddbfbf822 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 6 Feb 2024 01:09:27 +0000 Subject: [PATCH 70/71] add documentation --- algorithmic_efficiency/random_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index f2b797453..24196d7b1 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -36,6 +36,7 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: int) -> SeedType: rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + # Truncate data to 32-bits, since numpy does not support 64-bit ints. rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data) & 0xffffffff) new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) return new_seed_1 + new_seed_2 From f888a99ec60aea954ba3e9bed536ec836d711bd2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 6 Feb 2024 23:08:19 +0000 Subject: [PATCH 71/71] add fold in method --- algorithmic_efficiency/random_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 24196d7b1..80bda1973 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -21,6 +21,10 @@ MAX_INT32 = 2**31 MIN_INT32 = -MAX_INT32 +# SALT constants +_SALT1 = np.random.RandomState(seed=5).randint(MIN_INT32, MAX_INT32, dtype=np.int32) +_SALT2 = np.random.RandomState(seed=6).randint(MIN_INT32, MAX_INT32, dtype=np.int32) + SeedType = Union[int, list, np.ndarray] @@ -33,13 +37,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: return np.array([s + 2**32 if s < 0 else s for s in seed.tolist()]) -def _fold_in(seed: SeedType, data: int) -> SeedType: - rng_1 = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed_1 = rng_1.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - # Truncate data to 32-bits, since numpy does not support 64-bit ints. - rng_2 = np.random.RandomState(seed=_signed_to_unsigned(data) & 0xffffffff) - new_seed_2 = rng_2.randint(MIN_INT32, MAX_INT32, dtype=np.int32) - return new_seed_1 + new_seed_2 +def _fold_in(seed, data, verbose = True): + a = np.random.RandomState(seed=_signed_to_unsigned(seed ^ _SALT1)).randint(MIN_INT32, MAX_INT32, dtype=np.int32) + b = np.random.RandomState(seed=_signed_to_unsigned(data ^ _SALT2)).randint(MIN_INT32, MAX_INT32, dtype=np.int32) + c = np.random.RandomState(seed=_signed_to_unsigned(a ^ b)).randint(MIN_INT32, MAX_INT32, dtype=np.int32) + return c def _split(seed: SeedType, num: int = 2) -> SeedType: