Skip to content

Commit

Permalink
Merge pull request #613 from runame/traindiffs
Browse files Browse the repository at this point in the history
Add traindiffs tests to workflows
  • Loading branch information
priyakasimbeg authored Jan 25, 2024
2 parents db055a6 + 047475c commit 8108c00
Show file tree
Hide file tree
Showing 27 changed files with 115 additions and 52 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/traindiffs_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Containerized training differences tests between Jax and PyTorch

on:
pull_request:
branches:
- 'main'

jobs:
build_and_push_docker_image:
runs-on: self-hosted
steps:
- uses: actions/checkout@v2
- name: Build and push docker image
run: |
GIT_BRANCH=${{ github.head_ref || github.ref_name }}
FRAMEWORK=both
IMAGE_NAME="algoperf_${GIT_BRANCH}"
cd $HOME/algorithmic-efficiency/docker
docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
BUILD_RETURN=$?
if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
traindiffs_tests:
runs-on: self-hosted
needs: build_and_push_docker_image
steps:
- uses: actions/checkout@v2
- name: Run containerized traindiffs test
run: |
docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }}
docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_${{ github.head_ref || github.ref_name }} algoperf_${{ github.head_ref || github.ref_name }} --traindiffs_test true
23 changes: 18 additions & 5 deletions docker/scripts/startup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ function usage() {
$0 [--dataset dataset] [--framework framework] [--submission_path submission_path]
[--tuning_search_space tuning_search_space] [--experiment_name experiment_name]
[--workload workload] [--max_global_steps max_global_steps] [--rsync_data rsync_data]
[--internal_contributor true]
[--internal_contributor true] [--traindiffs_test false]
Options:
-d | --dataset: Can be imagenet, criteo1tb, ogbg, fastmri, wmt, librispeech.
-f | --framework: Can be jax or pytorch.
Expand All @@ -34,11 +34,13 @@ function usage() {
from internal GCP bucket.
-i | --internal_contributor: If true, allow rsync of data and transfer of experiment results
with GCP project.
--traindiffs_test: If true, ignore all other options and run the traindiffs test.
USAGE
exit 1
}

# Defaults
TEST="false"
INTERNAL_CONTRIBUTOR_MODE="false"
HOME_DIR=""
RSYNC_DATA="true"
Expand All @@ -47,7 +49,11 @@ SAVE_CHECKPOINTS="true"

# Pass flag
while [ "$1" != "" ]; do
case $1 in
case $1 in
--traindiffs_test)
shift
TEST=$1
;;
-d | --dataset)
shift
DATASET=$1
Expand Down Expand Up @@ -106,8 +112,15 @@ while [ "$1" != "" ]; do
;;
esac
shift
done

done

if [[ ${TEST} == "true" ]]; then
cd algorithmic-efficiency
COMMAND="python3 tests/test_traindiffs.py"
echo $COMMAND
eval $COMMAND
exit
fi

# Check if arguments are valid
VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/criteo1tb/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \
Criteo1TbDlrmSmallWorkload as JaxWorkload
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallWorkload as PytWorkload
Criteo1TbDlrmSmallWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -51,7 +51,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
'inputs': torch.ones((2, 13 + 26)),
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/criteo1tb_embed_init/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \
Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallEmbedInitWorkload as PytWorkload
Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -50,7 +50,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
'inputs': torch.ones((2, 13 + 26)),
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/criteo1tb_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \
Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallLayerNormWorkload as PytWorkload
Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -62,7 +62,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
'inputs': torch.ones((2, 13 + 26)),
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/criteo1tb_resnet/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \
Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallResNetWorkload as PytWorkload
Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -62,7 +62,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
'inputs': torch.ones((2, 13 + 26)),
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/fastmri/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
FastMRIWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
FastMRIWorkload as PytWorkload
FastMRIWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -55,7 +55,7 @@ def sort_key(k):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 320, 320)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/fastmri_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
FastMRILayerNormWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
FastMRILayerNormWorkload as PytWorkload
FastMRILayerNormWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -62,7 +62,7 @@ def sort_key(k):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 320, 320)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/fastmri_model_size/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
FastMRIModelSizeWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
FastMRIModelSizeWorkload as PytWorkload
FastMRIModelSizeWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -55,7 +55,7 @@ def sort_key(k):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 320, 320)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/fastmri_tanh/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \
FastMRITanhWorkload as JaxWorkload
from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \
FastMRITanhWorkload as PytWorkload
FastMRITanhWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -55,7 +55,7 @@ def sort_key(k):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 320, 320)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/imagenet_resnet/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \
ImagenetResNetWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \
ImagenetResNetWorkload as PytWorkload
ImagenetResNetWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -72,7 +72,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 3, 224, 224)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/imagenet_resnet/gelu_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \
ImagenetResNetGELUWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \
ImagenetResNetGELUWorkload as PytWorkload
ImagenetResNetGELUWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.imagenet_resnet.compare import key_transform
from tests.modeldiffs.imagenet_resnet.compare import sd_transform
Expand All @@ -19,7 +19,7 @@
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 3, 224, 224)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/imagenet_resnet/silu_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \
ImagenetResNetSiLUWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \
ImagenetResNetSiLUWorkload as PytWorkload
ImagenetResNetSiLUWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.imagenet_resnet.compare import key_transform
from tests.modeldiffs.imagenet_resnet.compare import sd_transform
Expand All @@ -19,7 +19,7 @@
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 3, 224, 224)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/imagenet_vit/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \
ImagenetVitWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \
ImagenetVitWorkload as PytWorkload
ImagenetVitWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -85,7 +85,7 @@ def key_transform(k):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 3, 224, 224)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/imagenet_vit_glu/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \
ImagenetVitGluWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \
ImagenetVitGluWorkload as PytWorkload
ImagenetVitGluWorkload as PyTorchWorkload

sd_transform = None

if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 3, 224, 224)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/imagenet_vit_postln/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \
ImagenetVitPostLNWorkload as JaxWorkload
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \
ImagenetVitPostLNWorkload as PytWorkload
ImagenetViTPostLNWorkload as PyTorchWorkload

sd_transform = None

if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
image = torch.randn(2, 3, 224, 224)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/librispeech_conformer/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \
LibriSpeechConformerWorkload as PytWorkload
LibriSpeechConformerWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -58,7 +58,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
wave = torch.randn(2, 320000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \
LibriSpeechConformerAttentionTemperatureWorkload as PytWorkload
LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -58,7 +58,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
wave = torch.randn(2, 320000)
Expand Down
4 changes: 2 additions & 2 deletions tests/modeldiffs/librispeech_conformer_gelu/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \
LibriSpeechConformerGeluWorkload as JaxWorkload
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \
LibriSpeechConformerGeluWorkload as PytWorkload
LibriSpeechConformerGeluWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff


Expand Down Expand Up @@ -58,7 +58,7 @@ def sd_transform(sd):
# pylint: disable=locally-disabled, not-callable

jax_workload = JaxWorkload()
pytorch_workload = PytWorkload()
pytorch_workload = PyTorchWorkload()

# Test outputs for identical weights and inputs.
wave = torch.randn(2, 320000)
Expand Down
Loading

0 comments on commit 8108c00

Please sign in to comment.