diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml
index 7e6913c6b..3a0736fa2 100644
--- a/.github/workflows/regression_tests.yml
+++ b/.github/workflows/regression_tests.yml
@@ -6,131 +6,178 @@ on:
       - 'main'
 
 jobs:
+  build_and_push_jax_docker_image:
+    runs-on: self-hosted
+    steps:
+    - uses: actions/checkout@v2
+    - name: Build and push docker images
+      run: |
+        GIT_BRANCH=${{ github.head_ref || github.ref_name }}
+        FRAMEWORK=jax
+        IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
+        cd $HOME/algorithmic-efficiency/docker
+        docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
+        BUILD_RETURN=$?
+        if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
+        docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
+        docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
+  build_and_push_pytorch_docker_image:
+    runs-on: self-hosted
+    steps:
+    - uses: actions/checkout@v2
+    - name: Build and push docker images
+      run: |
+        GIT_BRANCH=${{ github.head_ref || github.ref_name }}
+        FRAMEWORK=pytorch
+        IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
+        cd $HOME/algorithmic-efficiency/docker
+        docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
+        BUILD_RETURN=$?
+        if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
+        docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
+        docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
   fastmri_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d fastmri -f jax -s baselines/adamw/jax/submission.py -w fastmri -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s baselines/adamw/jax/submission.py -w fastmri -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   imagenet_resnet_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d imagenet -f jax -s baselines/adamw/jax/submission.py -w imagenet_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s baselines/adamw/jax/submission.py -w imagenet_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   imagenet_vit_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d imagenet -f jax -s baselines/adamw/jax/submission.py -w imagenet_vit -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s baselines/adamw/jax/submission.py -w imagenet_vit -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   ogbg_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d ogbg -f jax -s baselines/adamw/jax/submission.py -w ogbg -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s baselines/adamw/jax/submission.py -w ogbg -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   criteo_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   librispeech_conformer_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_conformer -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_conformer -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   librispeech_deepspeech_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_deepspeech -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s baselines/adamw/jax/submission.py -w librispeech_deepspeech -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   wmt_jax:
     runs-on: self-hosted
+    needs: build_and_push_jax_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev -d wmt -f jax -s baselines/adamw/jax/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s baselines/adamw/jax/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   fastmri_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d fastmri -f pytorch -s baselines/adamw/pytorch/submission.py -w fastmri -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s baselines/adamw/pytorch/submission.py -w fastmri -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   imagenet_resnet_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d imagenet -f pytorch -s baselines/adamw/pytorch/submission.py -w imagenet_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s baselines/adamw/pytorch/submission.py -w imagenet_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   imagenet_vit_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d imagenet -f pytorch -s baselines/adamw/pytorch/submission.py -w imagenet_vit -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s baselines/adamw/pytorch/submission.py -w imagenet_vit -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   ogbg_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d ogbg -f pytorch -s baselines/adamw/pytorch/submission.py -w ogbg -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s baselines/adamw/pytorch/submission.py -w ogbg -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   criteo_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        exit $?
   librispeech_conformer_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d librispeech -f pytorch -s baselines/adamw/pytorch/submission.py -w librispeech_conformer -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s baselines/adamw/pytorch/submission.py -w librispeech_conformer -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   librispeech_deepspeech_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d librispeech -f pytorch -s baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
   wmt_pytorch:
     runs-on: self-hosted
+    needs: build_and_push_pytorch_docker_image
     steps:
     - uses: actions/checkout@v2
     - name: Run containerized workload
       run: |
-        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev
-        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_dev -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
+        docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
+        docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s baselines/adamw/pytorch/submission.py -w wmt -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
diff --git a/README.md b/README.md
index 54e274a6c..216354e73 100644
--- a/README.md
+++ b/README.md
@@ -123,7 +123,7 @@ To use the Docker container as an interactive virtual environment, you can run a
    --gpus all \
    --ipc=host \
-   -keep_container_alive true
+   --keep_container_alive true
    ```
 2. Open a bash terminal
    ```bash
@@ -148,8 +148,8 @@ python3 submission_runner.py \
     --workload=mnist \
     --experiment_dir=$HOME/experiments \
     --experiment_name=my_first_experiment \
-    --submission_path=reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py \
-    --tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json
+    --submission_path=baselines/adamw/jax/submission.py \
+    --tuning_search_space=baselines/adamw/tuning_search_space.json
 ```
 
 **Pytorch**
@@ -160,8 +160,8 @@ python3 submission_runner.py \
     --workload=mnist \
     --experiment_dir=$HOME/experiments \
     --experiment_name=my_first_experiment \
-    --submission_path=reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py \
-    --tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json
+    --submission_path=baselines/adamw/pytorch/submission.py \
+    --tuning_search_space=baselines/adamw/tuning_search_space.json
 ```
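The two quickstart hunks above repoint `--submission_path` at the AdamW baseline. For orientation, this is roughly the surface `submission_runner.py` expects a submission module to export; the names follow `algorithmic_efficiency/spec.py`, but the signatures here are an abbreviated sketch (a required `data_selection` function is omitted), not the authoritative definitions.

```python
# Illustrative sketch of a submission module's required functions.
from typing import Any, Dict

from algorithmic_efficiency import spec


def get_batch_size(workload_name: str) -> int:
  """Global batch size the submission wants for this workload."""
  return 1024  # Placeholder value.


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: Any,
                         rng: spec.RandomState) -> Any:
  """Creates the optimizer (e.g. AdamW) and returns its initial state."""
  raise NotImplementedError


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: Any,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: Any,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: Any,
                  eval_results: list,
                  global_step: int,
                  rng: spec.RandomState):
  """Runs one training step and returns the updated optimizer/model state."""
  raise NotImplementedError
```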
@@ -186,10 +186,10 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc
 submission_runner.py \
     --framework=pytorch \
     --workload=mnist \
-    --experiment_dir=/home/znado \
+    --experiment_dir=$HOME/experiments \
     --experiment_name=baseline \
-    --submission_path=reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py \
-    --tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json \
+    --submission_path=baselines/adamw/pytorch/submission.py \
+    --tuning_search_space=baselines/adamw/tuning_search_space.json
 ```
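The torchrun recipe above works because the PyTorch code paths detect a distributed launch from the environment variables torchrun injects, as in the `USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ` checks in the workload files later in this diff. A condensed sketch of that pattern; the process-group setup at the end is a typical completion under that assumption, not a verbatim excerpt from the repo:

```python
import os

import torch
import torch.distributed as dist

# torchrun sets LOCAL_RANK (plus RANK and WORLD_SIZE) for each process it
# spawns, so its presence distinguishes a DDP launch from a plain
# `python3 submission_runner.py` run.
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
RANK = int(os.environ.get('LOCAL_RANK', 0))
N_GPUS = torch.cuda.device_count()

if USE_PYTORCH_DDP:
  # Assumed completion: bind this process to its GPU and join the NCCL group.
  torch.cuda.set_device(RANK)
  dist.init_process_group('nccl')
```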
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
index f5875ac30..d47f1b484 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -18,7 +18,7 @@ def dot_interact(concat_features):
   """
   batch_size = concat_features.shape[0]
 
-  # Interact features, select upper or lower-triangular portion, and re-shape.
+  # Interact features, select upper or lower-triangular portion, and reshape.
   xactions = jnp.matmul(concat_features,
                         jnp.transpose(concat_features, [0, 2, 1]))
   feature_dim = xactions.shape[-1]
@@ -46,7 +46,7 @@ class DlrmSmall(nn.Module):
     embed_dim: embedding dimension.
   """
 
-  vocab_size: int = 32 * 128 * 1024  # 4_194_304
+  vocab_size: int = 32 * 128 * 1024  # 4_194_304.
   num_dense_features: int = 13
   mlp_bottom_dims: Sequence[int] = (512, 256, 128)
   mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
index 8d7f4e2f9..ba8db9ced 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -1,4 +1,5 @@
 """Criteo1TB workload implemented in Jax."""
+
 import functools
 from typing import Dict, Optional, Tuple
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py
index 380401e9e..de6b4d1dd 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -6,33 +6,23 @@
 from torch import nn
 
 
-def dot_interact(concat_features):
-  """Performs feature interaction operation between dense or sparse features.
-  Input tensors represent dense or sparse features.
-  Pre-condition: The tensors have been stacked along dimension 1.
-  Args:
-    concat_features: Array of features with shape [B, n_features, feature_dim].
-  Returns:
-    activations: Array representing interacted features.
-  """
-  batch_size = concat_features.shape[0]
-
-  # Interact features, select upper or lower-triangular portion, and re-shape.
-  xactions = torch.bmm(concat_features,
-                       torch.permute(concat_features, (0, 2, 1)))
-  feature_dim = xactions.shape[-1]
-
-  indices = torch.triu_indices(feature_dim, feature_dim)
-  num_elems = indices.shape[1]
-  indices = torch.tile(indices, [1, batch_size])
-  indices0 = torch.reshape(
-      torch.tile(
-          torch.reshape(torch.arange(batch_size), [-1, 1]), [1, num_elems]),
-      [1, -1])
-  indices = tuple(torch.cat((indices0, indices), 0))
-  activations = xactions[indices]
-  activations = torch.reshape(activations, [batch_size, -1])
-  return activations
+class DotInteract(nn.Module):
+  """Performs feature interaction operation between dense or sparse features."""
+
+  def __init__(self, num_sparse_features):
+    super().__init__()
+    self.triu_indices = torch.triu_indices(num_sparse_features + 1,
+                                           num_sparse_features + 1)
+
+  def forward(self, dense_features, sparse_features):
+    combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
+                                dim=1)
+    interactions = torch.bmm(combined_values,
+                             torch.transpose(combined_values, 1, 2))
+    interactions_flat = interactions[:,
+                                     self.triu_indices[0],
+                                     self.triu_indices[1]]
+    return torch.cat((dense_features, interactions_flat), dim=1)
 
 
 class DlrmSmall(nn.Module):
@@ -62,13 +52,21 @@ def __init__(self,
     self.mlp_top_dims = mlp_top_dims
     self.embed_dim = embed_dim
 
-    self.embedding_table = nn.Embedding(self.vocab_size, self.embed_dim)
-    self.embedding_table.weight.data.uniform_(0, 1)
-    # Scale the initialization to fan_in for each slice.
+    # Ideally, we should use the pooled embedding implementation from
+    # `TorchRec`. However, to keep the implementation identical to the Jax
+    # one, we define a single embedding matrix, split into chunks.
+    num_chunks = 4
+    assert vocab_size % num_chunks == 0
+    self.embedding_table_chunks = []
     scale = 1.0 / torch.sqrt(self.vocab_size)
-    self.embedding_table.weight.data = scale * self.embedding_table.weight.data
+    for i in range(num_chunks):
+      chunk = nn.Parameter(
+          torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+      chunk.data.uniform_(0, 1)
+      chunk.data = scale * chunk.data
+      self.register_parameter(f'embedding_chunk_{i}', chunk)
+      self.embedding_table_chunks.append(chunk)
 
-    # bottom mlp
     bottom_mlp_layers = []
     input_dim = self.num_dense_features
     for dense_dim in self.mlp_bottom_dims:
@@ -84,8 +82,9 @@ def __init__(self,
             0.,
             math.sqrt(1. / module.out_features))
 
-    # top mlp
-    # TODO (JB): Write down the formula here instead of the constant.
+    self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
+
+    # input_dims = bottom-MLP output dim + the upper triangle (incl. diagonal)
+    # of the (num_sparse_features + 1)-way interaction matrix:
+    # 128 + 27 * 28 / 2 = 506 for the 26 Criteo sparse features.
     input_dims = 506
     top_mlp_layers = []
     num_layers_top = len(self.mlp_top_dims)
@@ -110,19 +109,26 @@ def __init__(self,
             math.sqrt(1. / module.out_features))
 
   def forward(self, x):
-    bot_mlp_input, cat_features = torch.split(
+    batch_size = x.shape[0]
+
+    dense_features, sparse_features = torch.split(
         x, [self.num_dense_features, self.num_sparse_features], 1)
-    cat_features = cat_features.to(dtype=torch.int32)
-    bot_mlp_output = self.bot_mlp(bot_mlp_input)
-    batch_size = bot_mlp_output.shape[0]
-    feature_stack = torch.reshape(bot_mlp_output,
-                                  [batch_size, -1, self.embed_dim])
-    idx_lookup = torch.reshape(cat_features, [-1]) % self.vocab_size
-    embed_features = self.embedding_table(idx_lookup)
-    embed_features = torch.reshape(embed_features,
-                                   [batch_size, -1, self.embed_dim])
-    feature_stack = torch.cat([feature_stack, embed_features], axis=1)
-    dot_interact_output = dot_interact(concat_features=feature_stack)
-    top_mlp_input = torch.cat([bot_mlp_output, dot_interact_output], axis=-1)
-    logits = self.top_mlp(top_mlp_input)
+
+    # Bottom MLP.
+    embedded_dense = self.bot_mlp(dense_features)
+
+    # Sparse feature processing.
+    sparse_features = sparse_features.to(dtype=torch.int32)
+    idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+    embedding_table = torch.cat(self.embedding_table_chunks, dim=0)
+    embedded_sparse = embedding_table[idx_lookup]
+    embedded_sparse = torch.reshape(embedded_sparse,
+                                    [batch_size, -1, self.embed_dim])
+
+    # Dot product interactions.
+    concatenated_dense = self.dot_interact(
+        dense_features=embedded_dense, sparse_features=embedded_sparse)
+
+    # Final MLP.
+    logits = self.top_mlp(concatenated_dense)
     return logits
diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index e8c8df992..993d82c9d 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -1,8 +1,8 @@
 """Criteo1TB workload implemented in PyTorch."""
+
 import contextlib
-from typing import Dict, Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple
 
-import jax
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -22,7 +22,7 @@ class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
 
   @property
   def eval_batch_size(self) -> int:
-    return 262_144
+    return 32_768
 
   def _per_example_sigmoid_binary_cross_entropy(
       self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
@@ -66,11 +66,6 @@ def loss_fn(
         'per_example': per_example_losses,
     }
 
-  def _eval_metric(self, logits: spec.Tensor,
-                   targets: spec.Tensor) -> Dict[str, int]:
-    summed_loss = self.loss_fn(logits, targets)['summed']
-    return {'loss': summed_loss}
-
   def init_model_fn(
       self,
       rng: spec.RandomState,
@@ -79,6 +74,8 @@ def init_model_fn(
     """Only dropout is used."""
     del aux_dropout_rate
     torch.random.manual_seed(rng[0])
+    # Disable cudnn benchmark to avoid OOM errors.
+    torch.backends.cudnn.benchmark = False
     model = DlrmSmall(
         vocab_size=self.vocab_size,
         num_dense_features=self.num_dense_features,
@@ -130,25 +127,28 @@ def model_fn(
 
     return logits_batch, None
 
-  def _build_input_queue(self,
-                         data_rng: jax.random.PRNGKey,
-                         split: str,
-                         data_dir: str,
-                         global_batch_size: int,
-                         num_batches: Optional[int] = None,
-                         repeat_final_dataset: bool = False):
+  def _build_input_queue(
+      self,
+      data_rng: spec.RandomState,
+      split: str,
+      data_dir: str,
+      global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
     not_train = split != 'train'
     per_device_batch_size = int(global_batch_size / N_GPUS)
 
     # Only create and iterate over tf input pipeline in one Python process to
     # avoid creating too many threads.
     if RANK == 0:
-      np_iter = super()._build_input_queue(data_rng,
-                                           split,
-                                           data_dir,
-                                           global_batch_size,
-                                           num_batches,
-                                           repeat_final_dataset)
+      np_iter = super()._build_input_queue(
+          data_rng=data_rng,
+          split=split,
+          data_dir=data_dir,
+          global_batch_size=global_batch_size,
+          num_batches=num_batches,
+          repeat_final_dataset=repeat_final_dataset)
     weights = None
     while True:
       if RANK == 0:
diff --git a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py b/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
index 17877d068..cb091b3a5 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
@@ -5,6 +5,7 @@
 validation). See here for the NVIDIA example:
 https://github.com/NVIDIA/DeepLearningExamples/blob/4e764dcd78732ebfe105fc05ea3dc359a54f6d5e/PyTorch/Recommendation/DLRM/preproc/run_spark_cpu.sh#L119.
 """
+
 import functools
 import os
 from typing import Optional
diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py
index 83ced02ce..801716de7 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -1,21 +1,24 @@
 """Criteo1TB DLRM workload base class."""
+
 import math
 import os
-from typing import Dict, Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple
 
-import jax
+from absl import flags
 import torch.distributed as dist
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.workloads.criteo1tb import input_pipeline
 
+FLAGS = flags.FLAGS
+
 USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
 
 
 class BaseCriteo1TbDlrmSmallWorkload(spec.Workload):
   """Criteo1tb workload."""
 
-  vocab_size: int = 32 * 128 * 1024  # 4_194_304
+  vocab_size: int = 32 * 128 * 1024  # 4_194_304.
   num_dense_features: int = 13
   mlp_bottom_dims: Tuple[int, int] = (512, 256, 128)
   mlp_top_dims: Tuple[int, int, int] = (1024, 1024, 512, 256, 1)
@@ -26,14 +29,15 @@ def target_metric_name(self) -> str:
     """The name of the target metric (useful for scoring/processing code)."""
     return 'loss'
 
-  def has_reached_validation_target(self, eval_result: float) -> bool:
+  def has_reached_validation_target(self, eval_result: Dict[str,
+                                                            float]) -> bool:
     return eval_result['validation/loss'] < self.validation_target_value
 
   @property
   def validation_target_value(self) -> float:
     return 0.123649
 
-  def has_reached_test_target(self, eval_result: float) -> bool:
+  def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
     return eval_result['test/loss'] < self.test_target_value
 
   @property
@@ -75,19 +79,22 @@ def train_stddev(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 7703  # ~2 hours
+    return 7703  # ~2 hours.
 
   @property
   def eval_period_time_sec(self) -> int:
-    return 2 * 60
-
-  def _build_input_queue(self,
-                         data_rng: jax.random.PRNGKey,
-                         split: str,
-                         data_dir: str,
-                         global_batch_size: int,
-                         num_batches: Optional[int] = None,
-                         repeat_final_dataset: bool = False):
+    return 2 * 600  # 20 mins.
+
+  def _build_input_queue(
+      self,
+      data_rng: spec.RandomState,
+      split: str,
+      data_dir: str,
+      global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+    del cache
     ds = input_pipeline.get_criteo1tb_dataset(
         split=split,
         shuffle_rng=data_rng,
@@ -121,11 +128,11 @@ def _eval_model_on_split(self,
     if split not in self._eval_iters:
       # These iterators will repeat indefinitely.
      self._eval_iters[split] = self._build_input_queue(
-          rng,
-          split,
-          data_dir,
-          global_batch_size,
-          num_batches,
+          data_rng=rng,
+          split=split,
+          data_dir=data_dir,
+          global_batch_size=global_batch_size,
+          num_batches=num_batches,
           repeat_final_dataset=True)
     loss = 0.0
     for _ in range(num_batches):
diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index 6eed68f83..db7bdd7d1 100644
--- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -281,7 +281,7 @@ def __init__(self, config: DeepspeechConfig):
 
   def forward(self, inputs, input_paddings):
     inputs = self.bn(inputs, input_paddings)
-    lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu()
+    lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
     packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
         inputs, lengths, batch_first=True, enforce_sorted=False)
     packed_outputs, _ = self.lstm(packed_inputs)
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
index 9fbc48578..0ce943b3b 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -11,27 +11,6 @@
 from torch.nn.init import xavier_uniform_
 
 
-# Mask making utilities ported to PyTorch from
-# https://github.com/google/flax/blob/main/flax/linen/attention.py.
-def make_attention_mask(query_input: Tensor,
-                        key_input: Tensor,
-                        pairwise_fn: Callable[..., Any] = torch.mul,
-                        dtype: torch.dtype = torch.float32) -> Tensor:
-  """Mask-making helper for attention weights.
-
-  Args:
-    query_input: a batched, flat input of query_length size
-    key_input: a batched, flat input of key_length size
-    pairwise_fn: broadcasting elementwise comparison function
-    dtype: mask return dtype
-
-  Returns:
-    A `[batch..., len_q, len_kv]` shaped attention mask.
-  """
-  mask = pairwise_fn(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
-  return mask.to(dtype)
-
-
 def make_causal_mask(x: Tensor,
                      device: str = 'cuda:0',
                      dtype: torch.dtype = torch.float32) -> Tensor:
@@ -47,17 +26,21 @@ def make_causal_mask(x: Tensor,
   """
   idxs = torch.broadcast_to(
       torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape)
-  return make_attention_mask(idxs, idxs, torch.greater_equal, dtype=dtype)
+  return torch.greater_equal(idxs.unsqueeze(-1),
+                             idxs.unsqueeze(-2)).to(dtype=dtype)
 
 
 def make_src_mask(src, inputs_segmentation, nhead):
   """Utility for creating src mask and adjust it for PyTorch Transformer API."""
-  src_mask = make_attention_mask(src > 0, src > 0)
+  src_mask = torch.mul((src > 0).unsqueeze(-1),
+                       (src > 0).unsqueeze(-2)).to(dtype=torch.float32)
   # Add segmentation block-diagonal attention mask if using segmented data.
   if inputs_segmentation is not None:
     src_mask = torch.logical_and(
         src_mask,
-        make_attention_mask(inputs_segmentation, inputs_segmentation, torch.eq))
+        torch.eq(
+            inputs_segmentation.unsqueeze(-1),
+            inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32))
   # Flip values and ensure numerical stability.
   src_mask = torch.repeat_interleave(
       torch.logical_not(src_mask), repeats=nhead, dim=0)
@@ -76,23 +59,27 @@ def make_tgt_and_memory_mask(tgt,
      Transformer API."""
   if not decode:
     tgt_mask = torch.logical_and(
-        make_attention_mask(tgt > 0, tgt > 0),
+        torch.mul((tgt > 0).unsqueeze(-1),
+                  (tgt > 0).unsqueeze(-2)).to(dtype=torch.float32),
         make_causal_mask(tgt, device=tgt.device))
-    memory_mask = make_attention_mask(tgt > 0, src > 0)
+    memory_mask = torch.mul((tgt > 0).unsqueeze(-1),
+                            (src > 0).unsqueeze(-2)).to(dtype=torch.float32)
   else:
     tgt_mask = None
-    memory_mask = make_attention_mask(torch.ones_like(tgt) > 0, src > 0)
+    memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1),
+                            (src > 0).unsqueeze(-2)).to(dtype=torch.float32)
   # Add segmentation block-diagonal attention masks if using segmented data.
   if inputs_segmentation is not None:
     tgt_mask = torch.logical_and(
         tgt_mask,
-        make_attention_mask(targets_segmentation,
-                            targets_segmentation,
-                            torch.eq))
+        torch.eq(
+            targets_segmentation.unsqueeze(-1),
+            targets_segmentation.unsqueeze(-2)).to(dtype=torch.float32))
     memory_mask = torch.logical_and(
         memory_mask,
-        make_attention_mask(targets_segmentation, inputs_segmentation,
-                            torch.eq))
+        torch.eq(
+            targets_segmentation.unsqueeze(-1),
+            inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32))
   # Flip values and ensure numerical stability.
   memory_mask = torch.repeat_interleave(
       torch.logical_not(memory_mask), repeats=nhead, dim=0)
@@ -617,7 +604,7 @@ def forward(self,
       memory: the sequence from the last layer of the encoder (required).
       tgt_mask: the mask for the tgt sequence (optional).
      memory_mask: the mask for the memory sequence (optional).
-      decode: wether to use cache for autoregressive decoding or not.
+      decode: whether to use cache for autoregressive decoding or not.
       max_len: maximum sequence length, necessary for decoding cache.
 
     Shape:
       see the docs in Transformer class.
@@ -1214,7 +1201,7 @@ def multi_head_attention_forward(query: Tensor,
     else:
       assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
          f'float, byte, and bool types are supported, not {attn_mask.dtype}'
-    # ensure attn_mask's dim is 3
+    # Ensure attn_mask's dim is 3.
     if attn_mask.dim() == 2:
       correct_2d_size = (tgt_len, src_len)
       if attn_mask.shape != correct_2d_size:
diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py
index f1131fc4e..ebc3d3f83 100644
--- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py
@@ -77,17 +77,27 @@ def predict_step(self,
                    max_decode_len: int,
                    beam_size: int = 4) -> spec.Tensor:
     """Predict translation with fast decoding beam search on a batch."""
-    params = params.module if isinstance(params, (DP, DDP)) else params
+    # params = params.module if isinstance(params, (DP, DDP)) else params
+    if hasattr(params, 'module'):
+      params = params.module
     params.eval()
-    encoder = params.encoder
+
+    if hasattr(params, '_modules'):
+      params = params._modules
+      encoder = params["encoder"]
+      decoder = params["decoder"]
+    else:
+      encoder = params.encoder
+      decoder = params.decoder
+
     if N_GPUS > 1 and not USE_PYTORCH_DDP:
       encoder = DP(encoder)
+    if N_GPUS > 1 and not USE_PYTORCH_DDP:
+      decoder = DP(decoder)
+
     encoded_inputs = torch.repeat_interleave(
         encoder(inputs), repeats=beam_size, dim=0)
     raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0)
-    decoder = params.decoder
-    if N_GPUS > 1 and not USE_PYTORCH_DDP:
-      decoder = DP(decoder)
 
     def tokens_ids_to_logits(
         flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
diff --git a/baselines/adamw/pytorch/submission.py b/baselines/adamw/pytorch/submission.py
index 6a086ff2d..75a4abbef 100644
--- a/baselines/adamw/pytorch/submission.py
+++ b/baselines/adamw/pytorch/submission.py
@@ -106,7 +106,7 @@ def update_params(workload: spec.Workload,
     optimizer_state['scheduler'].step()
 
   # Log training metrics - loss, grad_norm, batch_size.
-  if global_step <= 100 or global_step % 500 == 0:
+  if global_step <= 10 or global_step % 500 == 0:
     with torch.no_grad():
       parameters = [p for p in current_model.parameters() if p.grad is not None]
       grad_norm = torch.norm(
diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index 0227e728e..d1636a3e5 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -24,8 +24,8 @@
 Note that some of the disk usage number below may be underestimates if the
 temp and final data dir locations are on the same drive.
 
-Criteo download size: ~350GB
-Criteo final disk size: ~1TB
+Criteo 1TB download size: ~350GB
+Criteo 1TB final disk size: ~1TB
 FastMRI download size:
 FastMRI final disk size:
 LibriSpeech download size:
@@ -65,6 +65,16 @@
 # pylint: disable=logging-format-interpolation
 # pylint: disable=consider-using-with
 
+# isort: off
+import tensorflow_datasets as tfds
+from torchvision.datasets import CIFAR10
+
+from algorithmic_efficiency.workloads.wmt import tokenizer
+from algorithmic_efficiency.workloads.wmt.input_pipeline import \
+    normalize_feature_names
+from datasets import librispeech_preprocess
+from datasets import librispeech_tokenizer
+
 import functools
 import os
 import resource
@@ -75,10 +85,12 @@
 from absl import app
 from absl import flags
 from absl import logging
+import re
 import requests
-import tensorflow_datasets as tfds
-from torchvision.datasets import CIFAR10
 import tqdm
+import urllib.parse
+
+import tensorflow as tf
 
 IMAGENET_TRAIN_TAR_FILENAME = 'ILSVRC2012_img_train.tar'
 IMAGENET_VAL_TAR_FILENAME = 'ILSVRC2012_img_val.tar'
@@ -87,12 +99,6 @@
 FASTMRI_VAL_TAR_FILENAME = 'knee_singlecoil_val.tar.xz'
 FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz'
 
-from algorithmic_efficiency.workloads.wmt import tokenizer
-from algorithmic_efficiency.workloads.wmt.input_pipeline import \
-    normalize_feature_names
-from datasets import librispeech_preprocess
-from datasets import librispeech_tokenizer
-
 flags.DEFINE_boolean(
     'interactive_deletion',
     True,
@@ -104,9 +110,9 @@
     'Whether or not to download all datasets. If false, can download some '
     'combination of datasets by setting the individual dataset flags below.')
 
-flags.DEFINE_boolean('criteo',
+flags.DEFINE_boolean('criteo1tb',
                      False,
-                     'If --all=false, whether or not to download Criteo.')
+                     'If --all=false, whether or not to download Criteo 1TB.')
 
 flags.DEFINE_boolean('cifar',
                      False,
                      'If --all=false, whether or not to download CIFAR-10.')
@@ -207,7 +213,7 @@ def _download_url(url, data_dir, name=None):
     file_path = os.path.join(data_dir, url.split('/')[-1])
   else:
     file_path = os.path.join(data_dir, name)
-  print(f"about to download to {file_path}")
+  logging.info(f'About to download to {file_path}')
 
   response = requests.get(url, stream=True, timeout=600)
   total_size_in_bytes = int(response.headers.get('Content-length', 0))
@@ -240,34 +246,78 @@ def _download_url(url, data_dir, name=None):
       url=url, n=progress_bar.n, size=progress_bar.total))
 
 
-def download_criteo(data_dir,
-                    tmp_dir,
-                    num_decompression_threads,
-                    interactive_deletion):
-  criteo_dir = os.path.join(data_dir, 'criteo')
-  tmp_criteo_dir = os.path.join(tmp_dir, 'criteo')
+def download_criteo1tb(data_dir,
+                       tmp_dir,
+                       num_decompression_threads,
+                       interactive_deletion):
+  criteo_dir = os.path.join(data_dir, 'criteo1tb')
+  tmp_criteo_dir = os.path.join(tmp_dir, 'criteo1tb')
   _maybe_mkdir(criteo_dir)
   _maybe_mkdir(tmp_criteo_dir)
+
+  # Forked from
+  # https://github.com/iamleot/transferwee/blob/master/transferwee.py.
+  user_agent = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) '
+                'Gecko/20100101 Firefox/102.0')
+  criteo_wetransfer_url = (
+      'https://criteo.wetransfer.com/downloads/'
+      '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2')
+  _, _, transfer_id, security_hash = urllib.parse.urlparse(
+      criteo_wetransfer_url).path.split('/')
+
+  session = requests.Session()
+  session.headers.update({
+      'User-Agent': user_agent,
+      'x-requested-with': 'XMLHttpRequest',
+  })
+  r = session.get('https://wetransfer.com/')
+  m = re.search('name="csrf-token" content="([^"]+)"', r.text)
+  if m:
+    session.headers.update({'x-csrf-token': m.group(1)})
+
+  get_url_request = session.post(
+      f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download',
+      json={
+          'intent': 'entire_transfer',
+          'security_hash': security_hash,
+      })
+  session.close()
+
+  download_url = get_url_request.json().get('direct_link')
+
+  logging.info(f'Downloading ~342GB Criteo 1TB data .zip file:\n{download_url}')
+  download_request = requests.get(  # pylint: disable=missing-timeout
+      download_url,
+      headers={'User-Agent': user_agent},
+      stream=True)
+
+  all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
+  with open(all_days_zip_filepath, 'wb') as f:
+    for chunk in download_request.iter_content(chunk_size=1024):
+      f.write(chunk)
+
+  unzip_cmd = f'unzip {all_days_zip_filepath} -d {tmp_criteo_dir}'
+  logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
+  p = subprocess.Popen(unzip_cmd, shell=True)
+  p.communicate()
+  _maybe_prompt_for_deletion(all_days_zip_filepath, interactive_deletion)
+
+  # Unzip the individual days.
   processes = []
   gz_paths = []
-  # Download and unzip.
   for day in range(24):
-    logging.info(f'Downloading Criteo day {day}...')
-    wget_cmd = (
-        f'wget --no-clobber --directory-prefix="{tmp_criteo_dir}" '
-        f'https://sacriteopcail01.z16.web.core.windows.net/day_{day}.gz')
     input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz')
     gz_paths.append(input_path)
     unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
     unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > '
                  f'"{unzipped_path}"')
-    command_str = f'{wget_cmd} && {unzip_cmd}'
-    logging.info(f'Running Criteo download command:\n{command_str}')
-    processes.append(subprocess.Popen(command_str, shell=True))
+    logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}')
+    processes.append(subprocess.Popen(unzip_cmd, shell=True))
   for p in processes:
     p.communicate()
   _maybe_prompt_for_deletion(gz_paths, interactive_deletion)
-  # Split into files with 1M lines each: day_1.csv -> day_1_[0-40].csv.
+
+  # Split into files with 5M lines each: day_1.csv -> day_1_[0-39].csv.
   for batch in range(6):
     batch_processes = []
     unzipped_paths = []
@@ -276,9 +326,9 @@ def download_criteo(data_dir,
       unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
       unzipped_paths.append(unzipped_path)
       split_path = os.path.join(criteo_dir, f'day_{day}_')
-      split_cmd = ('split -a 3 -d -l 1000000 --additional-suffix=.csv '
+      split_cmd = ('split -a 3 -d -l 5000000 --additional-suffix=.csv '
                    f'"{unzipped_path}" "{split_path}"')
-      logging.info(f'Running Criteo split command:\n{split_cmd}')
+      logging.info(f'Running Criteo 1TB split command:\n{split_cmd}')
       batch_processes.append(subprocess.Popen(split_cmd, shell=True))
     for p in batch_processes:
       p.communicate()
@@ -296,16 +346,16 @@ def download_cifar(data_dir, framework):
 
 
 def extract_filename_from_url(url, start_str='knee', end_str='.xz'):
-  """ the url filenames are sometimes couched within a urldefense+aws access id etc. string.
-  unfortunately querying the content disposition in requests fails (not provided)...
-  so fast search is done here within the url
-  """
+  """ The url filenames are sometimes couched within a urldefense+aws access id
+  etc. string. Unfortunately querying the content disposition in requests fails
+  (not provided)... so fast search is done here within the url.
+  """
   failure = -1
   start = url.find(start_str)
   end = url.find(end_str)
   if failure in (start, end):
     raise ValueError(
-        f"Unable to locate filename wrapped in {start}--{end} in {url}")
+        f'Unable to locate filename wrapped in {start}--{end} in {url}')
   end += len(end_str)  # make it inclusive
   return url[start:end]
 
@@ -342,9 +392,9 @@ def download_fastmri(data_dir,
 def extract(source, dest):
   if not os.path.exists(dest):
     os.path.makedirs(dest)
-  print(f"extracting {source} to {dest}")
+  logging.info(f'Extracting {source} to {dest}')
   tar = tarfile.open(source)
-  print(f"opened tar")
+  logging.info('Opened tar')
   tar.extractall(dest)
   tar.close()
 
@@ -373,7 +423,7 @@ def setup_fastmri(data_dir, src_data_dir):
   logging.info('Unzipping {} to {}'.format(test_tar_file_path, test_data_dir))
   extract(test_tar_file_path, test_data_dir)
   logging.info('Set up fastMRI dataset complete')
-  print(f"extraction completed! ")
+  logging.info('Extraction completed!')
 
 
 def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
@@ -579,19 +629,19 @@ def main(_):
   data_dir = os.path.abspath(os.path.expanduser(data_dir))
   logging.info('Downloading data to %s...', data_dir)
 
-  if FLAGS.all or FLAGS.criteo:
-    logging.info('Downloading criteo...')
-    download_criteo(data_dir,
-                    tmp_dir,
-                    num_decompression_threads,
-                    FLAGS.interactive_deletion)
+  if FLAGS.all or FLAGS.criteo1tb:
+    logging.info('Downloading criteo1tb...')
+    download_criteo1tb(data_dir,
+                       tmp_dir,
+                       num_decompression_threads,
+                       FLAGS.interactive_deletion)
 
   if FLAGS.all or FLAGS.mnist:
     logging.info('Downloading MNIST...')
     download_mnist(data_dir)
 
   if FLAGS.all or FLAGS.fastmri:
-    print(f"starting fastMRI download...\n")
+    logging.info('Starting fastMRI download...\n')
     logging.info('Downloading FastMRI...')
     knee_singlecoil_train_url = FLAGS.fastmri_knee_singlecoil_train_url
     knee_singlecoil_val_url = FLAGS.fastmri_knee_singlecoil_val_url
@@ -600,8 +650,8 @@ def main(_):
                 knee_singlecoil_val_url,
                 knee_singlecoil_test_url):
       raise ValueError(
-          f'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url to '
-          'download the FastMRI dataset.\nSign up for the URLs at '
+          'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url '
+          'to download the FastMRI dataset.\nSign up for the URLs at '
           'https://fastmri.med.nyu.edu/.')
 
     updated_data_dir = download_fastmri(data_dir,
@@ -609,7 +659,7 @@ def main(_):
                                         knee_singlecoil_val_url,
                                         knee_singlecoil_test_url)
 
-    print(f"fastMRI download completed. Extracting...")
+    logging.info('fastMRI download completed. Extracting...')
     setup_fastmri(data_dir, updated_data_dir)
 
   if FLAGS.all or FLAGS.imagenet:
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d178d6bf1..ab6a798c1 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -14,6 +14,8 @@ RUN apt-get update
 RUN apt-get install -y curl tar
 RUN apt-get install -y git python3 pip wget ffmpeg
 RUN apt-get install libtcmalloc-minimal4
+RUN apt-get install unzip
+RUN apt-get install pigz
 RUN export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4
 
 # Install GCP tools
diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh
index f3c891c6f..9e0e68ca9 100644
--- a/docker/build_docker_images.sh
+++ b/docker/build_docker_images.sh
@@ -1,7 +1,11 @@
+#!/bin/bash
 # Bash script to build and push dev docker images to artifact repo
 # Usage:
 #   bash build_docker_images.sh -b
 
+# Make program exit with non-zero exit code if any command fails.
+set -e
+
 while getopts b: flag
 do
   case "${flag}" in
@@ -31,4 +35,4 @@ do
   eval $DOCKER_PUSH_COMMAND
   echo "To pull container run: "
   echo $DOCKER_PULL_COMMAND
-done
\ No newline at end of file
+done
diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh
index cdd2c649c..410d21532 100644
--- a/docker/scripts/startup.sh
+++ b/docker/scripts/startup.sh
@@ -222,4 +222,5 @@ then
   done
 fi
 
+echo "Exiting with $RETURN_CODE"
 exit $RETURN_CODE
diff --git a/submission_runner.py b/submission_runner.py
index 1850c598e..f4ee32ede 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -15,6 +15,7 @@
 """
 
 import datetime
+import gc
 import importlib
 import json
 import os
@@ -154,6 +155,14 @@ def _get_time_ddp():
   get_time = _get_time
 
 
+def _reset_cuda_mem():
+  if FLAGS.framework == 'pytorch' and torch.cuda.is_available():
+    torch._C._cuda_clearCublasWorkspaces()
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+
+
 def train_once(
     workload: spec.Workload,
     global_batch_size: int,
@@ -191,9 +200,11 @@ def train_once(
   model_params, model_state = workload.init_model_fn(
       model_init_rng, dropout_rate, aux_dropout_rate)
   if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
-    compile_error_workloads = ['ogbg', 'librispeech_deepspeech', 'wmt']
-    eager_backend_workloads = ['librispeech_conformer']
-    aot_eager_backend_workloads = ['criteo1tb']
+    compile_error_workloads = ['ogbg', 'criteo1tb']
+    eager_backend_workloads = [
+        'librispeech_conformer', 'librispeech_deepspeech'
+    ]
+    aot_eager_backend_workloads = []
     if FLAGS.workload in compile_error_workloads:
       logging.warning(
          'These workloads cannot be fully compiled under current '
@@ -322,6 +333,9 @@ def train_once(
       if ((train_step_end_time - train_state['last_eval_time']) >=
          workload.eval_period_time_sec or train_state['training_complete']):
        with profiler.profile('Evaluation'):
+          del batch
+          _reset_cuda_mem()
+
           try:
             eval_start_time = get_time()
            latest_eval_result = workload.eval_model(global_eval_batch_size,
@@ -331,7 +345,7 @@ def train_once(
                                                      data_dir,
                                                      imagenet_v2_data_dir,
                                                      global_step)
-            # Check if targets reached
+            # Check if targets reached.
             train_state['validation_goal_reached'] = (
                 workload.has_reached_validation_target(latest_eval_result) or
                 train_state['validation_goal_reached'])
@@ -339,15 +353,15 @@ def train_once(
                 workload.has_reached_test_target(latest_eval_result) or
                 train_state['test_goal_reached'])
 
-            # Save last eval time
+            # Save last eval time.
             eval_end_time = get_time()
             train_state['last_eval_time'] = eval_end_time
 
-            # Accumulate eval time
+            # Accumulate eval time.
             train_state[
                 'accumulated_eval_time'] += eval_end_time - eval_start_time
 
-            # Add times to eval results for logging
+            # Add times to eval results for logging.
             latest_eval_result['score'] = (
                 train_state['accumulated_submission_time'])
             latest_eval_result[
@@ -386,20 +400,18 @@ def train_once(
                   save_intermediate_checkpoints=FLAGS
                   .save_intermediate_checkpoints)
 
-            if FLAGS.framework == 'pytorch' and torch.cuda.is_available():
-              torch.cuda.empty_cache()
             logging_end_time = get_time()
-
             train_state['accumulated_logging_time'] += (
                 logging_end_time - logging_start_time)
 
+            _reset_cuda_mem()
+
           except RuntimeError as e:
             logging.exception(f'Eval step {global_step} error.\n')
             if 'out of memory' in str(e):
               logging.warning('Error: GPU out of memory during eval during step '
                               f'{global_step}, error : {str(e)}.')
-              if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+              _reset_cuda_mem()
 
       train_state['last_step_end_time'] = get_time()
 
diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py
index fef9c2978..5b33d8b62 100644
--- a/tests/test_param_shapes.py
+++ b/tests/test_param_shapes.py
@@ -1,3 +1,5 @@
+from itertools import zip_longest
+
 import jax
 import numpy as np
 import pytest
@@ -53,13 +55,21 @@ def test_param_shapes(workload):
       jax_workload.param_shapes.unfreeze())
   pytorch_param_shapes = jax.tree_util.tree_leaves(
       pytorch_workload.param_shapes)
-  assert len(jax_param_shapes) == len(pytorch_param_shapes)
+  if workload == 'criteo1tb':
+    # The PyTorch implementation divides the embedding matrix into 4 chunks,
+    # so it has 3 more parameter leaves than the Jax implementation.
+    assert len(jax_param_shapes) == len(pytorch_param_shapes) - 3
+  else:
+    assert len(jax_param_shapes) == len(pytorch_param_shapes)
   # Check if total number of params deduced from shapes match.
   num_jax_params = 0
   num_pytorch_params = 0
-  for jax_shape, pytorch_shape in zip(jax_param_shapes, pytorch_param_shapes):
-    num_jax_params += np.prod(jax_shape.shape_tuple)
-    num_pytorch_params += np.prod(pytorch_shape.shape_tuple)
+  for jax_shape, pytorch_shape in zip_longest(jax_param_shapes,
+                                              pytorch_param_shapes):
+    if jax_shape is not None:
+      num_jax_params += np.prod(jax_shape.shape_tuple)
+    if pytorch_shape is not None:
+      num_pytorch_params += np.prod(pytorch_shape.shape_tuple)
   assert num_jax_params == num_pytorch_params
diff --git a/tests/test_param_types.py b/tests/test_param_types.py
index 3679289ed..45e855759 100644
--- a/tests/test_param_types.py
+++ b/tests/test_param_types.py
@@ -129,6 +129,11 @@ def test_param_types(workload_name):
   jax_param_types_dict = count_param_types(jax_param_types)
   pytorch_param_types_dict = count_param_types(pytorch_param_types)
 
+  # The PyTorch implementation splits the embedding matrix into 4 chunks,
+  # which are counted as WEIGHT, so map them back to a single EMBEDDING.
+  if workload_name == 'criteo1tb':
+    pytorch_param_types_dict[spec.ParameterType.WEIGHT] -= 4
+    pytorch_param_types_dict[spec.ParameterType.EMBEDDING] = 1
+
   # Jax fuses LSTM cells together, whereas PyTorch exposes all the weight
   # parameters, and there are two per cell, for each of the forward and backward
   # directional LSTMs, and there are 6 layers of LSTM in librispeech_deepspeech,
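A quick sanity check of the chunked-embedding accounting the two tests above encode, assuming the workload defaults used in this diff (`vocab_size = 32 * 128 * 1024`, the default `embed_dim = 128`, and 4 chunks): PyTorch ends up with three more parameter leaves than Jax, while the total parameter count is unchanged.

```python
# Jax: one embedding table; PyTorch: 4 chunks of vocab_size // 4 rows each.
vocab_size, embed_dim, num_chunks = 32 * 128 * 1024, 128, 4

jax_leaves = [(vocab_size, embed_dim)]
pytorch_leaves = [(vocab_size // num_chunks, embed_dim)] * num_chunks

# 4 chunks - 1 table = the 3 extra leaves test_param_shapes expects.
assert len(pytorch_leaves) - len(jax_leaves) == 3
# Splitting the table does not change the number of parameters.
assert (sum(r * c for r, c in jax_leaves) ==
        sum(r * c for r, c in pytorch_leaves))
```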