diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml new file mode 100644 index 000000000..5bb84c867 --- /dev/null +++ b/.github/workflows/traindiffs_tests.yml @@ -0,0 +1,32 @@ +name: Containerized training differences tests between Jax and PyTorch + +on: + pull_request: + branches: + - 'main' + +jobs: + build_and_push_docker_image: + runs-on: self-hosted + steps: + - uses: actions/checkout@v2 + - name: Build and push docker image + run: | + GIT_BRANCH=${{ github.head_ref || github.ref_name }} + FRAMEWORK=both + IMAGE_NAME="algoperf_${GIT_BRANCH}" + cd $HOME/algorithmic-efficiency/docker + docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH + BUILD_RETURN=$? + if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi + docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME + traindiffs_tests: + runs-on: self-hosted + needs: build_and_push_docker_image + steps: + - uses: actions/checkout@v2 + - name: Run containerized traindiffs test + run: | + docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} + docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} algoperf_${{ github.head_ref || github.ref_name }} --traindiffs_test true diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 30cb6b36b..0d9ad1b1e 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -14,8 +14,8 @@ function usage() { $0 [--dataset dataset] [--framework framework] [--submission_path submission_path] [--tuning_search_space tuning_search_space] [--experiment_name experiment_name] [--workload workload] [--max_global_steps max_global_steps] [--rsync_data rsync_data] - [--internal_contributor true] - + [--internal_contributor true] [--traindiffs_test false] + Options: -d | --dataset: Can be imagenet, criteo1tb, ogbg, fastmri, wmt, librispeech. -f | --framework: Can be jax or pytorch. @@ -34,11 +34,13 @@ function usage() { from internal GCP bucket. -i | --internal_contributor: If true, allow rsync of data and transfer of experiment results with GCP project. + --traindiffs_test: If true, ignore all other options and run the traindiffs test. USAGE exit 1 } # Defaults +TEST="false" INTERNAL_CONTRIBUTOR_MODE="false" HOME_DIR="" RSYNC_DATA="true" @@ -47,7 +49,11 @@ SAVE_CHECKPOINTS="true" # Pass flag while [ "$1" != "" ]; do - case $1 in + case $1 in + --traindiffs_test) + shift + TEST=$1 + ;; -d | --dataset) shift DATASET=$1 @@ -106,8 +112,15 @@ while [ "$1" != "" ]; do ;; esac shift -done - +done + +if [[ ${TEST} == "true" ]]; then + cd algorithmic-efficiency + COMMAND="python3 tests/test_traindiffs.py" + echo $COMMAND + eval $COMMAND + exit +fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 9a95f3656..adbade983 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallWorkload as PytWorkload + Criteo1TbDlrmSmallWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -51,7 +51,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 719484037..0748e2d71 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallEmbedInitWorkload as PytWorkload + Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -50,7 +50,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 3fc2a750a..0a6e5c5ac 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -11,7 +11,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallLayerNormWorkload as PytWorkload + Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index b9dbbc80e..288442594 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -12,7 +12,7 @@ from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ - Criteo1TbDlrmSmallResNetWorkload as PytWorkload + Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() pyt_batch = { 'inputs': torch.ones((2, 13 + 26)), diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 6780ff91e..56b74b32d 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRIWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIWorkload as PytWorkload + FastMRIWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 4be086da3..23ccf26d7 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRILayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRILayerNormWorkload as PytWorkload + FastMRILayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -62,7 +62,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index 60d846b6f..b61516c29 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRIModelSizeWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRIModelSizeWorkload as PytWorkload + FastMRIModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 47bad372a..0f455387c 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ FastMRITanhWorkload as JaxWorkload from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ - FastMRITanhWorkload as PytWorkload + FastMRITanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -55,7 +55,7 @@ def sort_key(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 320, 320) diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 2fc721ab0..fb730f1bf 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetWorkload as PytWorkload + ImagenetResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -72,7 +72,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 8c3899076..6c8adbec2 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetGELUWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetGELUWorkload as PytWorkload + ImagenetResNetGELUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -19,7 +19,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index ee74e7bc9..7668cdbd9 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetSiLUWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ - ImagenetResNetSiLUWorkload as PytWorkload + ImagenetResNetSiLUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -19,7 +19,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index e0c2506a6..ba21e63da 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitWorkload as PytWorkload + ImagenetVitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -85,7 +85,7 @@ def key_transform(k): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index 444f1230a..2c0aa546d 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -13,7 +13,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitGluWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PytWorkload + ImagenetVitGluWorkload as PyTorchWorkload sd_transform = None @@ -21,7 +21,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index e73a140f5..8a9063cac 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -13,7 +13,7 @@ from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitPostLNWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitPostLNWorkload as PytWorkload + ImagenetViTPostLNWorkload as PyTorchWorkload sd_transform = None @@ -21,7 +21,7 @@ # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. image = torch.randn(2, 3, 224, 224) diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index d414001dd..cfe6c7381 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload as PytWorkload + LibriSpeechConformerWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index 64612fbf0..8480fca02 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerAttentionTemperatureWorkload as PytWorkload + LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index 892040b57..caa9b09b9 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerGeluWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerGeluWorkload as PytWorkload + LibriSpeechConformerGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 784fceb60..1a94d3c77 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerLayerNormWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerLayerNormWorkload as PytWorkload + LibriSpeechConformerLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -58,7 +58,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 12b79a517..edcc3ba87 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechWorkload as JaxWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ - LibriSpeechDeepSpeechWorkload as PytWorkload + LibriSpeechDeepSpeechWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -83,7 +83,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. wave = torch.randn(2, 320000) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 8d0ee8411..41fc5ee17 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkload as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkload as PytWorkload + WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index b50abd3ca..92ce4eb44 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadAttentionTemp as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadAttentionTemp as PytWorkload + WmtWorkloadAttentionTemp as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 1322ad0a0..b8d860479 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadGLUTanH as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadGLUTanH as PytWorkload + WmtWorkloadGLUTanH as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index bfd701736..3f5469d8d 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -10,7 +10,7 @@ from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ WmtWorkloadPostLN as JaxWorkload from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ - WmtWorkloadPostLN as PytWorkload + WmtWorkloadPostLN as PyTorchWorkload from tests.modeldiffs.diff import out_diff @@ -106,7 +106,7 @@ def sd_transform(sd): # pylint: disable=locally-disabled, not-callable jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() + pytorch_workload = PyTorchWorkload() # Test outputs for identical weights and inputs. inp_tokens = torch.randint(low=0, high=32000, size=(2, 256)) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 5c43b233b..74c06e180 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -197,7 +197,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None): f'tests.modeldiffs.{workload_name}.compare') jax_params, model_state, _ = diff_utils.torch2jax( jax_workload=super(), - pytorch_workload=compare_module.PytWorkload(**self.init_kwargs), + pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs), key_transform=compare_module.key_transform, sd_transform=compare_module.sd_transform) return (FrozenDict(**jax_utils.replicate(jax_params)), diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index a1b64a573..663cf3de4 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -10,6 +10,7 @@ from absl import flags from absl.testing import absltest +from numpy import allclose FLAGS = flags.FLAGS @@ -81,6 +82,23 @@ def test_workload(self): print(header) print('=' * len(header)) for i in range(NUM_TRAIN_STEPS): + rtol = 1e-1 if workload == 'librispeech_deepspeech' else 5e-3 + self.assertTrue( + allclose( + jax_results['eval_results'][i][k], + pyt_results['eval_results'][i][k], + rtol=rtol)) + self.assertTrue( + allclose( + jax_results['scalars'][i]['grad_norm'], + pyt_results['scalars'][i]['grad_norm'], + rtol=rtol)) + self.assertTrue( + allclose( + jax_results['scalars'][i]['loss'], + pyt_results['scalars'][i]['loss'], + rtol=rtol)) + row = map(lambda x: str(round(x, 5)), [ jax_results['eval_results'][i][k],