From fa48558a43f608545e3424103cc2c7337a49041e Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Mon, 23 Oct 2023 08:37:53 +0000 Subject: [PATCH 01/29] wip Co-authored-by: Terry Kong --- .github/container/Dockerfile.jax | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index 28ef709d2..f0f2759e8 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -68,31 +68,15 @@ COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX} COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA} ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ -RUN mkdir -p /opt/pip-tools.d -RUN <<"EOF" bash -ex -echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/manifest.jax -echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/manifest.jax -EOF - -## Flax -ARG REPO_FLAX -ARG REF_FLAX -ARG SRC_PATH_FLAX -RUN get-source.sh -f ${REPO_FLAX} -r ${REF_FLAX} -d ${SRC_PATH_FLAX} -m /opt/pip-tools.d/manifest.flax +COPY --from=jax-builder ${SRC_PATH_JAX} ${SRC_PATH_JAX} +COPY --from=jax-builder ${SRC_PATH_XLA} ${SRC_PATH_XLA} +ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/ -## Transformer engine: check out source and build wheel -ARG REPO_TE -ARG REF_TE -ARG SRC_PATH_TE -ENV NVTE_FRAMEWORK=jax -ENV SRC_PATH_TE=${SRC_PATH_TE} -RUN <<"EOF" bash -ex -set -o pipefail -pip install ninja && rm -rf ~/.cache/pip -get-source.sh -f ${REPO_TE} -r ${REF_TE} -d ${SRC_PATH_TE} -pushd ${SRC_PATH_TE} -python setup.py bdist_wheel && rm -rf build -echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/manifest.te +RUN mkdir -p /opt/pip-tools.d +RUN <> /opt/pip-tools.d/requirements-jax.in +echo "$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in +echo "flax" >> /opt/pip-tools.d/requirements-jax.in EOF # TODO: properly configure entrypoint From 2ccf1a9365494328274c27e6a3241c8f8bf94c8f Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Mon, 23 Oct 2023 08:41:37 +0000 Subject: [PATCH 02/29] parent abb6f9704b8e32ed331ac56bd767338796224c3c author Yu-Hang Tang 1698050497 +0000 committer Terry Kong 1701417045 -0800 pip-compile changes Updated t5-large perf (#342) Update Pax README and sub file (#345) - Adds FP8 documentation - Updates perf table - Makes some other minor improvements for readability Adds CUDA_MODULE_LOADING=EAGER to core jax container env vars (#329) Re-enable NVLS in nightly containers (#331) NVLS was disabled due to a known issue in NCCL 2.17 that caused intermittent hangs. The issue has been resolved in NCCL 2.18, so we are safe to re-enable NVLS. --------- Co-authored-by: Terry Kong Update Pax TE patch to point to rebased branch (#348) Loosens t5x loss tests relative tolerances (#343) Relaxing the relative tolerance on the loss tests since it was leading to too many false positives. For reference, deviation in loss for the t5 model can sometimes be up to 15% at the start of training with real data. Adds rosetta-t5x TE + no-TE tests that enable the correct configs for testing (#332) - [ ] Add capability to retroactively test with newer test-t5x.sh like in [t5x-wget-test](https://github.com/NVIDIA/JAX-Toolbox/tree/t5x-wget-test) - [ ] Sets `ENABLE_TE=1` in the Dockerfile.t5x which is identical to the logic from before where it was always enabled in rosetta-t5x Fix markdown hyperlink for jax package on frontpage readme (#319) Adds a --seed option to test-t5x.sh to ensure determinism (#344) To ensure that the tests results for a particular container are reproducible between runs, this change introduces a seed argument that sets the jax seed and dataset seed to 42. It remains configurable, but now there shouldn't be variance given the same container. - Also fixes a typo where --steps-per-epoch wasn't in the usage doc of this script Co-authored-by: NVIDIA Co-authored-by: Yu-Hang "Maxin" Tang Dynamic workflow run names (#356) This change introduces the dynamic [run name field](https://github.blog/changelog/2022-09-26-github-actions-dynamic-names-for-workflow-runs/#:~:text=GitHub%20Actions%20customers%20can%20now,visit%20the%20GitHub%20Actions%20community.) `run-name`. It's currently difficult on mobile to find the "workflow_run" that corresponds to a particular date, so hopefully this helps identify which builds were nightly vs which builds were manually triggered. I couldn't find a good way to dynamically look up the `name` field, so for now I copied all of names. I also wasn't able to find a "created_at" for the scheduled workflows, so those don't have timestamps for now. __Assumptions__: * "workflow_run" == nightly since "scheduled" events only happen on `main` and `workflow_run` are only run for concrete workflows and not reusable workflows - [x] Test the workflow_run codepath - [x] Test the scheduled codepath ![image](https://github.com/NVIDIA/JAX-Toolbox/assets/7576060/4b916452-334a-4a73-9220-9fbadc70462f) Fix random failling tests for backend_independent on V100 (#351) Fixes randomly failures in the backend-independent section of JAX unit tests: ``` Cannot find a free accelerator to run the test on, exiting with failure ``` Changes: limit the number of concurrent test jobs even for backend-independent tests, which do create GPU contexts. As a clarification, `--jobs` and `--local_test_jobs` do not make a difference for our particular CI pipeline, since JAX is built in a separate CI job anyway. References (From Reed Wanderman-Milne @ Google): > 1. In particular, you have to set NB_GPUS, JOBS_PER_ACC, and J correctly or you can get that error (I recently got the same error by not setting those correctly) > 2. (also I think --jobs should be --local_test_jobs in that code block, no reason to restrict the number of jobs compiling JAX) Propagate error code in ViT tests (#357) Merges rosetta unit tests and takes off the marker which spun up another matrix job (#360) This should simplify the rosetta tests and save some time since another matrix job was started for one test Propagate build failures (#363) Always run the `publish-build` step, regardless of whether the rosetta pax/t5x build was attempted. This ensures that badges correctly reflect build failures due to dependent builds failing. Patch for JAX core container (ARM64) (#367) Add patch to XLA to be able to build JAX core container for ARM64 Update the doc for USE_FP8 (#349) This PR provides guidance on how to use the new configuration option, `USE_FP8`, to enable native FP8 support on Hopper GPUs. Update the native-fp8 guide with cudnn layer norm (#368) This PR updates the guide to include the new flag to enable the cudnn layer norm. cc. @ashors1 @terrykong @nouiz Add WAR for XLA NCCL bug causing OOMs (#362) A stopgap for https://github.com/NVIDIA/JAX-Toolbox/issues/346 fix TE multi-device test fix lzma build issue edit TE test name fix TE arm64 test install error remove --install option from get-source.sh fix TE arm64 test install error disable sandbox i'm jet-lagged use Pax image for TE testing Fix job dependency --- .github/container/Dockerfile.jax | 43 ++++++++++++++++++++---- .github/workflows/_sandbox.yaml | 8 ++++- .github/workflows/nightly-t5x-build.yaml | 1 + rosetta/rosetta/projects/t5x/README.md | 1 - rosetta/rosetta/projects/vit/README.md | 1 - 5 files changed, 45 insertions(+), 9 deletions(-) diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index f0f2759e8..ee9a1dbe1 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -37,6 +37,17 @@ RUN --mount=type=ssh \ --mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \ git clone "${REPO_XLA}" "${SRC_PATH_XLA}" && cd "${SRC_PATH_XLA}" && git checkout ${REF_XLA} +# TODO: This is a WAR to NCCL errors we observe in TOT. Should be removed when no longer needed +RUN <> /opt/pip-tools.d/requirements-jax.in -echo "$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in -echo "flax" >> /opt/pip-tools.d/requirements-jax.in +RUN <<"EOF" bash -ex +echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/manifest.jax +echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/manifest.jax +EOF + +## Flax +ARG REPO_FLAX +ARG REF_FLAX +ARG SRC_PATH_FLAX +RUN get-source.sh -f ${REPO_FLAX} -r ${REF_FLAX} -d ${SRC_PATH_FLAX} -m /opt/pip-tools.d/manifest.flax + +## Transformer engine: check out source and build wheel +ARG REPO_TE +ARG REF_TE +ARG SRC_PATH_TE +ENV NVTE_FRAMEWORK=jax +ENV SRC_PATH_TE=${SRC_PATH_TE} +RUN <<"EOF" bash -ex +set -o pipefail +pip install ninja && rm -rf ~/.cache/pip +get-source.sh -f ${REPO_TE} -r ${REF_TE} -d ${SRC_PATH_TE} +pushd ${SRC_PATH_TE} +python setup.py bdist_wheel && rm -rf build +echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/manifest.te EOF # TODO: properly configure entrypoint diff --git a/.github/workflows/_sandbox.yaml b/.github/workflows/_sandbox.yaml index 37fa6ca68..cc2adc056 100644 --- a/.github/workflows/_sandbox.yaml +++ b/.github/workflows/_sandbox.yaml @@ -1,7 +1,13 @@ name: "~Sandbox" on: - workflow_dispatch: + # workflow_dispatch: + # push: + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + packages: write # to upload container jobs: sandbox: diff --git a/.github/workflows/nightly-t5x-build.yaml b/.github/workflows/nightly-t5x-build.yaml index a503111b4..48f62bc43 100644 --- a/.github/workflows/nightly-t5x-build.yaml +++ b/.github/workflows/nightly-t5x-build.yaml @@ -71,6 +71,7 @@ jobs: runs-on: ubuntu-22.04 outputs: DOCKER_TAG_MEALKIT: '' + DOCKER_TAG_FINAL: '' steps: - name: Generate placeholder warning shell: bash -x -e {0} diff --git a/rosetta/rosetta/projects/t5x/README.md b/rosetta/rosetta/projects/t5x/README.md index aeec2f688..39401f415 100644 --- a/rosetta/rosetta/projects/t5x/README.md +++ b/rosetta/rosetta/projects/t5x/README.md @@ -197,7 +197,6 @@ t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh \ # Known Issues * There is a known sporadic NCCL crash that happens when using the T5x container at node counts greater than or equal to 32 nodes. We will fix this in the next release. The issue is tracked [here](https://github.com/NVIDIA/JAX-Toolbox/issues/194). -* The T5x nightlies disable `NCCL_NVLS_ENABLE=0` ([doc](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. # Changelog - Added Transformer Engine + FP8 support diff --git a/rosetta/rosetta/projects/vit/README.md b/rosetta/rosetta/projects/vit/README.md index 0c5b22a47..a57896480 100644 --- a/rosetta/rosetta/projects/vit/README.md +++ b/rosetta/rosetta/projects/vit/README.md @@ -157,5 +157,4 @@ Pre-training was performed on 1 node with a global batch size of 4096. Models we ## Known Issues 1. By default, gradient accumulation (GA) sums loss across the microbatches. As a result, loss is scaled up when using gradient accumulation, and training with GA only works when using a scale-invariant optimizer such as Adam or Adafactor. ViT fine-tuning is performed using SGD; thus, GA should not be used when fine-tuning. -2. The nightlies disable `NCCL_NVLS_ENABLE=0` ([doc](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. From 404d6291252493f40f68f4fe8c94e147dc6c2b9d Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 31 Oct 2023 09:37:44 -0700 Subject: [PATCH 03/29] Adds support for building rosetta with local patches and an already generated patch dir comment Add steps to archive patches in run Date the patches for readability Better log msg switch to --3way since that produces a merge conflict to help understand the conflict Switch to mealkit+finalize mechanic for rosetta builds Add github.run_id to artifacts for provenance Update all rosetta workflows with mealkit/final mechanism --- .github/workflows/_build_pax.yaml | 6 +- .github/workflows/_build_rosetta.yaml | 76 ++++++++++-- .github/workflows/_build_t5x.yaml | 6 +- .github/workflows/_ci.yaml | 2 - .../workflows/nightly-rosetta-pax-build.yaml | 86 +++++++++----- .../nightly-rosetta-t5x-build-test.yaml | 109 +++++++++++------- rosetta/Dockerfile.pax | 3 +- rosetta/Dockerfile.t5x | 6 +- rosetta/README.md | 17 ++- rosetta/create-distribution.sh | 68 +++++++++-- rosetta/scripts/extract-patches.sh | 18 +++ 11 files changed, 288 insertions(+), 109 deletions(-) create mode 100755 rosetta/scripts/extract-patches.sh diff --git a/.github/workflows/_build_pax.yaml b/.github/workflows/_build_pax.yaml index 62bc175ad..de2b3dafd 100644 --- a/.github/workflows/_build_pax.yaml +++ b/.github/workflows/_build_pax.yaml @@ -163,7 +163,7 @@ jobs: # bring in utility functions source .github/workflows/scripts/to_json.sh - badge_label='PAX ${{ inputs.ARCHITECTURE }} build' + badge_label='Upstream PAX ${{ inputs.ARCHITECTURE }} build' tags="${{ steps.final-metadata.outputs.tags }}" digest="${{ steps.final-build.outputs.digest }}" outcome="${{ steps.final-build.outcome }}" @@ -171,11 +171,11 @@ jobs: if [[ ${outcome} == "success" ]]; then badge_message="pass" badge_color=brightgreen - summary="PAX build on ${{ inputs.ARCHITECTURE }}: $badge_message" + summary="Upstream PAX build on ${{ inputs.ARCHITECTURE }}: $badge_message" else badge_message="fail" badge_color=red - summary="PAX build on ${{ inputs.ARCHITECTURE }}: $badge_message" + summary="Upstream PAX build on ${{ inputs.ARCHITECTURE }}: $badge_message" fi to_json \ diff --git a/.github/workflows/_build_rosetta.yaml b/.github/workflows/_build_rosetta.yaml index b62bd20ba..ce0cda580 100644 --- a/.github/workflows/_build_rosetta.yaml +++ b/.github/workflows/_build_rosetta.yaml @@ -21,9 +21,14 @@ on: description: 'Build date in YYYY-MM-DD format' required: false default: 'NOT SPECIFIED' + ARTIFACT_NAME: + type: string + description: 'Name of the artifact zip file' + required: false + default: 'artifact-rosetta-build' BADGE_FILENAME: type: string - description: 'Name of the endpoint JSON file for shields.io badge' + description: 'Name of the endpoint JSON file for shields.io badge (w/o .json || arch || library)' required: false default: 'badge-rosetta-build' outputs: @@ -48,7 +53,8 @@ jobs: build-rosetta: runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small] env: - BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME}}-${{ inputs.ARCHITECTURE}}.json + BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }}.json + ARTIFACT_NAME_FULL: ${{ inputs.ARTIFACT_NAME}}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }} outputs: DOCKER_TAG_MEALKIT: ${{ steps.mealkit-metadata.outputs.tags }} DOCKER_TAG_FINAL: ${{ steps.final-metadata.outputs.tags }} @@ -108,12 +114,17 @@ jobs: id: final-metadata uses: docker/metadata-action@v4 with: - images: ${{ env.UPLD_IMAGE }} - flavor: latest=false - tags: type=raw,value=${{ github.run_id }}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }}-final - labels: org.opencontainers.image.created=${{ inputs.BUILD_DATE }} + images: | + ${{ env.UPLD_IMAGE }} + flavor: | + latest=false + tags: | + type=raw,value=${{ github.run_id }}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }} + labels: + org.opencontainers.image.created=${{ inputs.BUILD_DATE }} - - name: Build docker images - final + - name: Build final image + id: final-build uses: docker/build-push-action@v4 with: context: rosetta/ @@ -125,3 +136,54 @@ jobs: target: final build-args: | BASE_IMAGE=${{ steps.defaults.outputs.BASE_IMAGE }} + + - name: Extract patches + run: rosetta/scripts/extract-patches.sh ${{ steps.final-metadata.outputs.tags }} + + - name: Archive generated patches + uses: actions/upload-artifact@v3 + with: + name: patches-${{ inputs.BASE_LIBRARY }}-${{ inputs.BUILD_DATE }}-${{ github.run_id }}-${{ inputs.ARCHITECTURE }} + path: rosetta/patches + + - name: Generate sitrep + if: success() || failure() + shell: bash -x -e {0} + run: | + # bring in utility functions + source .github/workflows/scripts/to_json.sh + + badge_label='${{ inputs.BASE_LIBRARY }} ${{ inputs.ARCHITECTURE }} build' + tags="${{ steps.final-metadata.outputs.tags }}" + digest="${{ steps.final-build.outputs.digest }}" + outcome="${{ steps.final-build.outcome }}" + + if [[ ${outcome} == "success" ]]; then + badge_message="pass" + badge_color=brightgreen + summary="${{ inputs.BASE_LIBRARY }} build on ${{ inputs.ARCHITECTURE }}: $badge_message" + else + badge_message="fail" + badge_color=red + summary="${{ inputs.BASE_LIBRARY }} build on ${{ inputs.ARCHITECTURE }}: $badge_message" + fi + + to_json \ + summary \ + badge_label tags digest outcome \ + > sitrep.json + + schemaVersion=1 \ + label="${badge_label}" \ + message="${badge_message}" \ + color="${badge_color}" \ + to_json schemaVersion label message color \ + > ${{ env.BADGE_FILENAME_FULL }} + + - name: Upload sitrep and badge + uses: actions/upload-artifact@v3 + with: + name: ${{ env.ARTIFACT_NAME_FULL }} + path: | + sitrep.json + ${{ env.BADGE_FILENAME_FULL }} diff --git a/.github/workflows/_build_t5x.yaml b/.github/workflows/_build_t5x.yaml index 6efcde8f9..890a47141 100644 --- a/.github/workflows/_build_t5x.yaml +++ b/.github/workflows/_build_t5x.yaml @@ -163,7 +163,7 @@ jobs: # bring in utility functions source .github/workflows/scripts/to_json.sh - badge_label='T5X ${{ inputs.ARCHITECTURE }} build' + badge_label='Upstream T5X ${{ inputs.ARCHITECTURE }} build' tags="${{ steps.final-metadata.outputs.tags }}" digest="${{ steps.final-build.outputs.digest }}" outcome="${{ steps.final-build.outcome }}" @@ -171,11 +171,11 @@ jobs: if [[ ${outcome} == "success" ]]; then badge_message="pass" badge_color=brightgreen - summary="T5X build on ${{ inputs.ARCHITECTURE }}: $badge_message" + summary="Upstream T5X build on ${{ inputs.ARCHITECTURE }}: $badge_message" else badge_message="fail" badge_color=red - summary="T5X build on ${{ inputs.ARCHITECTURE }}: $badge_message" + summary="Upstream T5X build on ${{ inputs.ARCHITECTURE }}: $badge_message" fi to_json \ diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 6ddbe7e69..9ff53847c 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -194,7 +194,6 @@ jobs: uses: ./.github/workflows/_test_t5x_rosetta.yaml with: T5X_IMAGE: ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }} - # Disable packing b/c rosetta-t5x images run with TE by default, and TE does not currently support packing secrets: inherit test-upstream-pax: @@ -212,4 +211,3 @@ jobs: with: PAX_IMAGE: ${{ needs.build-rosetta-pax.outputs.DOCKER_TAG_FINAL }} secrets: inherit - diff --git a/.github/workflows/nightly-rosetta-pax-build.yaml b/.github/workflows/nightly-rosetta-pax-build.yaml index 95deaa288..164f62438 100644 --- a/.github/workflows/nightly-rosetta-pax-build.yaml +++ b/.github/workflows/nightly-rosetta-pax-build.yaml @@ -54,17 +54,8 @@ jobs: "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/cancel" cat # blocks execution in case workflow cancellation takes time - - name: Determine if the resulting container should be 'published' - id: if-publish - shell: bash -x -e {0} - run: - # A container should be published if: - # 1) the workflow is triggered by workflow_dispatch and the PUBLISH input is true, or - # 2) the workflow is triggered by workflow_run (i.e., a nightly build) - echo "PUBLISH=${{ github.event_name == 'workflow_run' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT - - - name: Set build date - id: date + - name: Set build metadata + id: meta-vars shell: bash -x -e {0} run: | BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d') @@ -84,6 +75,10 @@ jobs: echo "BASE_LIBRARY=${{ env.BASE_LIBRARY }}" >> $GITHUB_OUTPUT echo "BASE_IMAGE_AMD64=${BASE_IMAGE_AMD64}" >> $GITHUB_OUTPUT echo "BASE_IMAGE_ARM64=${BASE_IMAGE_ARM64}" >> $GITHUB_OUTPUT + # A container should be published if: + # 1) the workflow is triggered by workflow_dispatch and the PUBLISH input is true, or + # 2) the workflow is triggered by workflow_run (i.e., a nightly build) + echo "PUBLISH=${{ github.event_name == 'workflow_run' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT amd64: needs: metadata @@ -105,6 +100,28 @@ jobs: BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_ARM64 }} secrets: inherit + publish-build-badge: + needs: [metadata, amd64, arm64] + uses: ./.github/workflows/_publish_badge.yaml + if: always() + with: + ENDPOINT_FILENAME: 'rosetta-pax-build-status.json' + PUBLISH: ${{ needs.metadata.outputs.PUBLISH == 'true' }} + SCRIPT: | + if [[ ${{ needs.amd64.result }} == "success" && ${{ needs.arm64.result }} == "success" ]]; then + BADGE_COLOR=brightgreen + MSG=passing + STATUS=success + else + BADGE_COLOR=red + MSG=failing + STATUS=failure + fi + echo "LABEL='nightly'" >> $GITHUB_OUTPUT + echo "MESSAGE='${MSG}'" >> $GITHUB_OUTPUT + echo "COLOR='${BADGE_COLOR}'" >> $GITHUB_OUTPUT + echo "STATUS='${STATUS}'" >> ${GITHUB_OUTPUT} + publish-mealkit: needs: [metadata, amd64, arm64] if: needs.metadata.outputs.PUBLISH == 'true' @@ -118,8 +135,17 @@ jobs: type=raw,value=mealkit,priority=500 type=raw,value=mealkit-${{ needs.metadata.outputs.BUILD_DATE }},priority=500 + # TODO: Test ARM when runners available + test-amd64: + needs: amd64 + uses: ./.github/workflows/_test_pax_rosetta.yaml + with: + PAX_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} + secrets: inherit + + # TODO: ARM Tests publish-final: - needs: [metadata, amd64, arm64] + needs: [metadata, amd64, arm64, test-amd64] if: needs.metadata.outputs.PUBLISH == 'true' uses: ./.github/workflows/_publish_container.yaml with: @@ -128,19 +154,11 @@ jobs: ${{ needs.arm64.outputs.DOCKER_TAG_FINAL }} TARGET_IMAGE: pax TARGET_TAGS: | - type=raw,value=latest,priority=1000 - type=raw,value=nightly-${{ needs.metadata.outputs.BUILD_DATE }},priority=900 - - test-pax: - needs: [metadata, amd64, arm64] - uses: ./.github/workflows/_test_pax_rosetta.yaml - if: (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') || github.event_name == 'workflow_dispatch' - with: - PAX_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} - secrets: inherit + ${{ needs.test-amd64.outputs.TEST_STATUS == 'success' && 'type=raw,value=latest,priority=1000' || '' }} + type=raw,value=nightly-${{ needs.metadata.outputs.BUILD_DATE }},priority=900 publish-pax: - needs: [metadata, test-pax] + needs: [metadata, test-amd64] uses: ./.github/workflows/_publish_t5x_pax_results.yaml if: (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') || github.event_name == 'workflow_dispatch' with: @@ -149,38 +167,42 @@ jobs: ARTIFACT_NAME: "rosetta-pax-" secrets: inherit - publish-test: - needs: [metadata, amd64, arm64, test-pax] + # TODO: ARM + publish-test-badge: + needs: [metadata, amd64, test-amd64] uses: ./.github/workflows/_publish_badge.yaml - if: ( always() ) + if: always() secrets: inherit with: ENDPOINT_FILENAME: 'rosetta-pax-overall-test-status.json' - PUBLISH: ${{ github.event_name == 'workflow_run' || needs.metadata.outputs.PUBLISH == 'true' }} + PUBLISH: ${{ needs.metadata.outputs.PUBLISH == 'true' }} SCRIPT: | - PAX_STATUS=${{ needs.test-pax.outputs.TEST_STATUS }} + PAX_STATUS=${{ needs.test-amd64.outputs.TEST_STATUS }} echo "LABEL='Tests'" >> $GITHUB_OUTPUT - if [[ ${{ needs.amd64.result }} == "success" && ${{ needs.arm64.result }} == "success" ]]; then + STATUS=failure + if [[ ${{ needs.amd64.result }} == "success" ]]; then if [[ $PAX_STATUS == "success" ]]; then COLOR=brightgreen MESSAGE="MGMN passed" + STATUS=success else COLOR=red MESSAGE="MGMN failed" fi else - MESSAGE="n/a" COLOR="red" + MESSAGE="n/a" fi - echo "MESSAGE='${MESSAGE}'" >> $GITHUB_OUTPUT echo "COLOR='${COLOR}'" >> $GITHUB_OUTPUT + echo "MESSAGE='${MESSAGE}'" >> $GITHUB_OUTPUT + echo "STATUS='${STATUS}'" >> ${GITHUB_OUTPUT} finalize: if: "!cancelled()" - needs: [metadata, amd64, arm64] + needs: [metadata, amd64, arm64, test-amd64] uses: ./.github/workflows/_finalize.yaml with: PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }} diff --git a/.github/workflows/nightly-rosetta-t5x-build-test.yaml b/.github/workflows/nightly-rosetta-t5x-build-test.yaml index 4e580f465..70adfb1f0 100644 --- a/.github/workflows/nightly-rosetta-t5x-build-test.yaml +++ b/.github/workflows/nightly-rosetta-t5x-build-test.yaml @@ -54,15 +54,6 @@ jobs: "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/cancel" cat # blocks execution in case workflow cancellation takes time - - name: Determine if the resulting container should be 'published' - id: if-publish - shell: bash -x -e {0} - run: - # A container should be published if: - # 1) the workflow is triggered by workflow_dispatch and the PUBLISH input is true, or - # 2) the workflow is triggered by workflow_run (i.e., a nightly build) - echo "PUBLISH=${{ github.event_name == 'workflow_run' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT - - name: Set build date id: date shell: bash -x -e {0} @@ -84,7 +75,11 @@ jobs: echo "BASE_LIBRARY=${{ env.BASE_LIBRARY }}" >> $GITHUB_OUTPUT echo "BASE_IMAGE_AMD64=${BASE_IMAGE_AMD64}" >> $GITHUB_OUTPUT echo "BASE_IMAGE_ARM64=${BASE_IMAGE_ARM64}" >> $GITHUB_OUTPUT - + # A container should be published if: + # 1) the workflow is triggered by workflow_dispatch and the PUBLISH input is true, or + # 2) the workflow is triggered by workflow_run (i.e., a nightly build) + echo "PUBLISH=${{ github.event_name == 'workflow_run' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT + amd64: needs: metadata uses: ./.github/workflows/_build_rosetta.yaml @@ -95,20 +90,46 @@ jobs: BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_AMD64 }} secrets: inherit + # TODO: Can't build ARM until https://github.com/NVIDIA/JAX-Toolbox/pull/252 is available arm64: needs: metadata runs-on: ubuntu-22.04 outputs: + DOCKER_TAG_FINAL: '' DOCKER_TAG_MEALKIT: '' steps: - name: Generate placeholder warning shell: bash -x -e {0} run: | echo "WARNING: arm64 build is not yet supported" - + + # TODO: ARM + publish-build-badge: + needs: [metadata, amd64, arm64] + uses: ./.github/workflows/_publish_badge.yaml + if: always() + secrets: inherit + with: + ENDPOINT_FILENAME: 'rosetta-t5x-build-status.json' + PUBLISH: ${{ needs.metadata.outputs.PUBLISH == 'true' }} + SCRIPT: | + if [[ ${{ needs.amd64.result }} == "success" && ${{ needs.arm64.result }} == "success" ]]; then + BADGE_COLOR=brightgreen + MSG=passing + STATUS=success + else + BADGE_COLOR=red + MSG=failing + STATUS=failure + fi + echo "LABEL='nightly'" >> $GITHUB_OUTPUT + echo "MESSAGE='${MSG}'" >> $GITHUB_OUTPUT + echo "COLOR='${BADGE_COLOR}'" >> $GITHUB_OUTPUT + echo "STATUS='${STATUS}'" >> ${GITHUB_OUTPUT} + publish-mealkit: needs: [metadata, amd64, arm64] - if: needs.metadata.output.PUBLISH == 'true' + if: needs.metadata.outputs.PUBLISH == 'true' uses: ./.github/workflows/_publish_container.yaml with: SOURCE_IMAGE: | @@ -117,65 +138,65 @@ jobs: TARGET_IMAGE: t5x TARGET_TAGS: | type=raw,value=mealkit,priority=500 - type=raw,value=mealkit-${{ needs.metadata.outputs.BUILD_DATE }},priority=500 + type=raw,value=mealkit-${{ needs.metadata.outputs.BUILD_DATE }},priority=500 - publish-final: - needs: [metadata, amd64, arm64] - if: needs.metadata.outputs.PUBLISH == 'true' - uses: ./.github/workflows/_publish_container.yaml - with: - SOURCE_IMAGE: | - ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} - ${{ needs.arm64.outputs.DOCKER_TAG_FINAL }} - TARGET_IMAGE: t5x - TARGET_TAGS: | - type=raw,value=latest,priority=1000 - type=raw,value=nightly-${{ needs.metadata.outputs.BUILD_DATE }},priority=900 - - test-unit: - if: (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') || github.event_name == 'workflow_dispatch' - needs: [metadata, amd64, arm64] + test-unit-amd64: + needs: amd64 uses: ./.github/workflows/_test_rosetta.yaml with: ROSETTA_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} secrets: inherit - test-t5x: - needs: [metadata, amd64, arm64] + test-t5x-amd64: + needs: amd64 uses: ./.github/workflows/_test_t5x_rosetta.yaml - if: (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') || github.event_name == 'workflow_dispatch' with: T5X_IMAGE: ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} secrets: inherit + publish-final: + needs: [metadata, amd64, arm64, test-t5x-amd64, test-unit-amd64] + if: needs.metadata.outputs.PUBLISH == 'true' + uses: ./.github/workflows/_publish_container.yaml + with: + SOURCE_IMAGE: | + ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} + ${{ needs.arm64.outputs.DOCKER_TAG_FINAL }} + TARGET_IMAGE: t5x + TARGET_TAGS: | + ${{ ( needs.test-t5x-amd64.outputs.TEST_STATUS == 'success' && needs.test-unit-amd64.outputs.TEST_STATUS == 'success' ) && 'type=raw,value=latest,priority=1000' || '' }} + type=raw,value=nightly-${{ needs.metadata.outputs.BUILD_DATE }},priority=900 + publish-t5x: - needs: [metadata, test-t5x] + needs: [metadata, test-t5x-amd64] uses: ./.github/workflows/_publish_t5x_pax_results.yaml - if: (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') || github.event_name == 'workflow_dispatch' with: BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} EXPERIMENT_SUBDIR: ROSETTA_T5X ARTIFACT_NAME: "rosetta-T5X-" secrets: inherit - publish-test: - needs: [metadata, amd64, arm64, test-unit, test-t5x] + # TODO: ARM + publish-test-badge: + needs: [metadata, amd64, test-unit-amd64, test-t5x-amd64] uses: ./.github/workflows/_publish_badge.yaml - if: ( always() ) + if: always() secrets: inherit with: ENDPOINT_FILENAME: 'rosetta-t5x-overall-test-status.json' - PUBLISH: ${{ github.event_name == 'workflow_run' || needs.metadata.outputs.PUBLISH == 'true' }} + PUBLISH: ${{ needs.metadata.outputs.PUBLISH == 'true' }} SCRIPT: | - UNIT_STATUS=${{ needs.test-unit.outputs.TEST_STATUS }} - T5X_STATUS=${{ needs.test-t5x.outputs.TEST_STATUS }} + UNIT_STATUS=${{ needs.test-unit-amd64.outputs.TEST_STATUS }} + T5X_STATUS=${{ needs.test-t5x-amd64.outputs.TEST_STATUS }} echo "LABEL='Tests'" >> $GITHUB_OUTPUT - if [[ ${{ needs.amd64.result }} == "success" && ${{ needs.arm64.result }} == "success" ]]; then + STATUS=failure + if [[ ${{ needs.amd64.result }} == "success" ]]; then if [[ $UNIT_STATUS == "success" ]] && [[ $T5X_STATUS == "success" ]]; then COLOR=brightgreen MESSAGE="Unit passed / MGMN passed" + STATUS=success elif [[ $UNIT_STATUS == "success" ]]; then COLOR=yellow MESSAGE="Unit passed / MGMN failed" @@ -187,18 +208,18 @@ jobs: MESSAGE="Unit failed / MGMN failed" fi else - MESSAGE="n/a" COLOR="red" + MESSAGE="n/a" fi echo "MESSAGE='${MESSAGE}'" >> $GITHUB_OUTPUT echo "COLOR='${COLOR}'" >> $GITHUB_OUTPUT + echo "STATUS='${STATUS}'" >> ${GITHUB_OUTPUT} finalize: if: "!cancelled()" - needs: [metadata, amd64, arm64] + needs: [metadata, amd64, arm64, test-t5x-amd64, test-unit-amd64] uses: ./.github/workflows/_finalize.yaml with: PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }} secrets: inherit - diff --git a/rosetta/Dockerfile.pax b/rosetta/Dockerfile.pax index b03e1ffd8..0ba10ae75 100644 --- a/rosetta/Dockerfile.pax +++ b/rosetta/Dockerfile.pax @@ -35,7 +35,7 @@ ARG PRAXIS_PATCHLIST COPY --from=rosetta-source / /opt/rosetta WORKDIR /opt/rosetta RUN --mount=target=/opt/pax-mirror,from=pax-mirror-source,readwrite \ - --mount=target=/opt/praxis-mirror,from=praxis-mirror-source,readwrite <> /opt/pip-tools.d/manifest.t5x echo "-e file:///opt/rosetta" >> /opt/pip-tools.d/manifest.t5x EOF diff --git a/rosetta/README.md b/rosetta/README.md index 7e0baa98e..c91a4ed3f 100644 --- a/rosetta/README.md +++ b/rosetta/README.md @@ -8,13 +8,20 @@ LLM, CV, and multimodal models. ```bash ROSETTA_BASE=t5x # or pax -docker buildx build --target rosetta --tag rosetta:latest -f Dockerfile.${ROSETTA_BASE} . - -# If you want a devel image with test dependencies -docker buildx build --target rosetta-devel --tag rosetta-devel:latest -f Dockerfile.${ROSETTA_BASE} . +docker buildx build --tag rosetta:latest -f Dockerfile.${ROSETTA_BASE} . # If you want to specify a specific base image -docker buildx build --target rosetta --tag rosetta:latest -f Dockerfile.${ROSETTA_BASE} --build-arg BASE_IMAGE=ghcr.io/nvidia/${ROSETTA_BASE}:nightly-2023-05-01 . +docker buildx build --tag rosetta:latest -f Dockerfile.${ROSETTA_BASE} --build-arg BASE_IMAGE=ghcr.io/nvidia/${ROSETTA_BASE}:mealkit-YYYY-MM-DD . +``` + +### Advanced use-cases +```sh +# [T5x Example] If you want to build with a different patchlist (patchlist must be relative to rosetta dir) +docker buildx build --build-arg T5X_PATCHLIST=patches/t5x/patchlist-t5x.txt.gen --build-arg FLAX_PATCHLIST=patches/flax/patchlist-flax.txt.gen --target rosetta --tag rosetta:latest -f Dockerfile.t5x . + +# [T5x Example] If you want to build with patches from another image +scripts/extract-patches.sh # Extracts generated patch dir under ./patches/ +docker buildx build --build-arg T5X_PATCHLIST=patches/t5x/patchlist-t5x.txt.gen --build-arg FLAX_PATCHLIST=patches/flax/patchlist-flax.txt.gen --target rosetta --tag rosetta:latest -f Dockerfile.t5x . ``` ## Development diff --git a/rosetta/create-distribution.sh b/rosetta/create-distribution.sh index 1eb99d111..a4db14ef3 100755 --- a/rosetta/create-distribution.sh +++ b/rosetta/create-distribution.sh @@ -14,6 +14,9 @@ Usage: $0 [OPTION]... -p, --patchlist=PATH Path to patchlist.txt with feature PRs -r, --ref=REF Git commit hash or tag name that specifies the base of the t5x distribution. Defaults to main (not origin/main) +A patchlist will be generated by this script and placed under $SCRIPT_DIR/patches/ with all entries +replaced with local patches. + Relationship between --dir, --extra-dir, and --mirror-url repo args: --dir: The upstream repo, locally cloned --mirror-url: A mirror of the upstream repo @@ -21,6 +24,8 @@ Relationship between --dir, --extra-dir, and --mirror-url repo args: Patches in the --patchlist will be applied from the repos above according to the following rules: + Local patches (relative to this file): + * ^file://.* --dir: * ^pull/.* --mirror-url: @@ -84,6 +89,13 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) INSTALLED_DIR=${INSTALLED_DIR:-/opt/t5x} DISTRIBUTION_BASE_REF=${DISTRIBUTION_BASE_REF:-HEAD} MIRROR_GIT_URL=${MIRROR_GIT_URL:-https://github.com/nvjax-svc-0/t5x.git} +GEN_PATCH_DIR=${GEN_PATCH_DIR:-$SCRIPT_DIR/patches/$(basename $(git -C $INSTALLED_DIR remote get-url origin) .git)} +GEN_PATCH_LIST=$GEN_PATCH_DIR/$(basename $PATCH_LIST).gen +if [[ -e $GEN_PATCH_LIST ]]; then + echo "[WARNING]: $GEN_PATCH_LIST exists and will be overwritten" + rm -f $GEN_PATCH_LIST +fi +mkdir -p $GEN_PATCH_DIR if [[ -z "${INSTALLED_DIR}" ]]; then echo "[ERROR]: Need to specify -d/--dir" @@ -177,6 +189,30 @@ fork-point() { merge_commit=$(git rev-list --ancestry-path ${feat_branch}..${main} | tail -n1) git merge-base ${merge_commit}^ ${feat_branch}^ } +# git-am + adds to generated patchlist +am+record() { + # Canonicalize path to remove extra slashes or dot syntax + patch_path=$(readlink -f $1) + if [[ ! $patch_path =~ ^${SCRIPT_DIR} ]]; then + echo "[ERROR]: patch_path=$patch_path should start with $SCRIPT_DIR" + exit 1 + fi + # Apply the patch + git am --3way <$patch_path || ret_code=$? + if [[ ${ret_code:-0} -ne 0 ]]; then + cat <> $GEN_PATCH_LIST +} apply-patches() { from=$1 to=$2 @@ -185,19 +221,22 @@ apply-patches() { if [[ $num_merge_commits -gt 0 ]]; then echo "[WARNING] There are merge commits between ${from}..${to}. Linearizing history before cherry-picking to remove merge-commits" >&2 # Make a tmp branch for the linear history - git checkout -b tmp-linear-tmp $to + to_linear=${to}.linearized + git checkout -b ${to_linear} $to # This will create a linear history git rebase $from # switch back to the rosetta-distribution branch git checkout - - to=tmp-linear-tmp + to=${to_linear} fi - git cherry-pick ${from}..${to} - ret_code=$? - if [[ $to == tmp-linear-tmp ]]; then - git branch -D tmp-linear-tmp + # Make the patch + patch_fname=$(tr '/' '-' <<< "$to").patch + git format-patch --stdout ${from}..${to} >$GEN_PATCH_DIR/$patch_fname + if [[ -n "${to_linear:-}" ]]; then + git branch -D ${to_linear} fi - return $ret_code + # Apply the patch + am+record $GEN_PATCH_DIR/$patch_fname } MIRROR_REMOTE_NAME=mirror if git remote show ${MIRROR_REMOTE_NAME} &>/dev/null; then @@ -210,7 +249,15 @@ for line in $(cat ${PATCH_LIST}); do continue fi git_ref=$(awk '{print $1}' <<< "${line}") - if [[ "${git_ref}" =~ ^pull/ ]]; then + if [[ "${git_ref}" =~ ^file:// ]]; then + patch_path=$SCRIPT_DIR/${git_ref#file://} + if [[ ! -f $patch_path ]]; then + echo "[ERROR]: ${git_ref} refers to $patch_path which does not exist" + exit 1 + fi + am+record $patch_path + continue + elif [[ "${git_ref}" =~ ^pull/ ]]; then REMOTE_NAME=origin PR_ID=$(cut -d/ -f2 <<<"${git_ref}") branch=PR-${PR_ID} @@ -229,7 +276,7 @@ for line in $(cat ${PATCH_LIST}); do main_branch=${REMOTE_NAME}/main else if [[ -z "${EXTRA_DIR+x}" ]] || [[ ! -d ${EXTRA_DIR} ]]; then - echo "[WARNING]: EXTRA_DIR=${EXTRA_DIR} does not exist so cannot cherry-pick ${git_ref}" + echo "[WARNING]: EXTRA_DIR=${EXTRA_DIR} does not exist so cannot apply patch=${git_ref}" continue fi REMOTE_NAME=${EXTRA_REMOTE_NAME} @@ -239,9 +286,8 @@ for line in $(cat ${PATCH_LIST}); do main_branch=${REMOTE_NAME}/main${TMP_BRANCH_SUFFIX} fi fork_point=$(fork-point ${main_branch} ${branch}) - ret_code=0 apply-patches ${fork_point} ${branch} || ret_code=$? - if [[ ${ret_code} -ne 0 ]]; then + if [[ ${ret_code:-0} -ne 0 ]]; then cat < /dev/null && pwd ) +cd $SCRIPT_DIR + +if [[ $# -lt 1 || $# -gt 2 ]]; then + echo "Copies the patches from within an image to the GIT_ROOT/rosetta/patches dir" + echo + echo "Usage: $0 " + exit 1 +fi + +IMAGE=$1 +ROSETTA_DIR=${2:-$(readlink -f ../)} + +container_id=$(docker create $IMAGE) +docker cp $container_id:/opt/rosetta/patches $ROSETTA_DIR +docker rm -v $container_id From 440471085d045746ec34e56049b6f6a5dddf5c72 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 17 Nov 2023 15:50:14 -0800 Subject: [PATCH 04/29] parent 8a43f4a741108ad610fb990a1c10b65dec607200 author Terry Kong 1700265014 -0800 committer Terry Kong 1701417338 -0800 parent 8a43f4a741108ad610fb990a1c10b65dec607200 author Terry Kong 1700265014 -0800 committer Terry Kong 1701417298 -0800 Adds (1) bump.sh which bumps the manifest and pins the patches (2) updates create-distribution.sh to work with manifests (3) move everything to .github/container sandbox fix write add propagation of trial branch to all workflows and update sandbox to test synchronous workflow check wip test wip changes wip wip don't need wip wip remove make trial branch contingent on publishing Update get-source and initial update for jax build to accept manifest update manifest jax build partially working + patches update pax/t5x dockerfiles, add more repos into manifest, and update pip-finalize to use *.in instead of manifest.txt update manifest with rest of repos and patches missing arg fix jax/pax/t5x all builds work now! update manifest file everywhere fix all workflows cleanup get the context right fix all broken tests custom pip distribution works --- .github/container/Dockerfile.base | 14 +- .github/container/Dockerfile.jax | 62 +- .github/container/Dockerfile.pax.amd64 | 18 +- .github/container/Dockerfile.pax.arm64 | 57 +- .github/container/Dockerfile.t5x | 8 +- .github/container/bump.sh | 97 + .../container}/create-distribution.sh | 219 +- .github/container/get-source.sh | 62 +- .github/container/manifest.yaml | 102 + .github/container/patches/flax/PR-3340.patch | 440 ++ .github/container/patches/paxml/PR-46.patch | 1567 ++++++++ .github/container/patches/praxis/PR-27.patch | 38 + .../t5x/mirror-patch-dali-support.patch | 387 ++ ...ror-patch-partial-checkpoint-restore.patch | 29 + ...ror-patch-t5x_te_in_contrib_noindent.patch | 3553 +++++++++++++++++ .github/container/pip-finalize.sh | 35 +- .github/container/pip-vcs-equivalency.patch | 68 + .github/workflows/_build_jax.yaml | 47 +- .github/workflows/_build_pax.yaml | 28 - .github/workflows/_build_rosetta.yaml | 8 +- .github/workflows/_build_t5x.yaml | 28 - .github/workflows/_ci.yaml | 60 +- .github/workflows/_sandbox.yaml | 17 +- .github/workflows/ci.yaml | 71 +- .github/workflows/nightly-jax-build.yaml | 48 +- rosetta/Dockerfile.pax | 82 +- rosetta/Dockerfile.t5x | 85 +- rosetta/README.md | 19 +- rosetta/patchlist-flax.txt | 8 - rosetta/patchlist-paxml.txt | 8 - rosetta/patchlist-praxis.txt | 8 - rosetta/patchlist-t5x.txt | 10 - rosetta/tests/extra-only-distribution.sh | 43 +- rosetta/tests/mirror-only-distribution.sh | 45 +- rosetta/{ => tests}/test-vit.sh | 0 rosetta/tests/upstream-only-distribution.sh | 41 +- 36 files changed, 6857 insertions(+), 555 deletions(-) create mode 100755 .github/container/bump.sh rename {rosetta => .github/container}/create-distribution.sh (59%) create mode 100644 .github/container/manifest.yaml create mode 100644 .github/container/patches/flax/PR-3340.patch create mode 100644 .github/container/patches/paxml/PR-46.patch create mode 100644 .github/container/patches/praxis/PR-27.patch create mode 100644 .github/container/patches/t5x/mirror-patch-dali-support.patch create mode 100644 .github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch create mode 100644 .github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch create mode 100644 .github/container/pip-vcs-equivalency.patch delete mode 100644 rosetta/patchlist-flax.txt delete mode 100644 rosetta/patchlist-paxml.txt delete mode 100644 rosetta/patchlist-praxis.txt delete mode 100644 rosetta/patchlist-t5x.txt rename rosetta/{ => tests}/test-vit.sh (100%) diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 8660e31de..a769bba06 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -38,12 +38,24 @@ RUN <<"EOF" bash -ex git config --global user.name "${GIT_USER_NAME}" git config --global user.email "${GIT_USER_EMAIL}" EOF -RUN pip install --upgrade --no-cache-dir pip pip-tools && rm -rf ~/.cache/* RUN mkdir -p /opt/pip-tools.d ADD --chmod=777 \ get-source.sh \ pip-finalize.sh \ /usr/local/bin/ +RUN wget https://github.com/mikefarah/yq/releases/latest/download/yq_linux_$(dpkg --print-architecture) -O /usr/local/bin/yq && \ + chmod 777 /usr/local/bin/yq +RUN git clone -b 23.3.1 https://github.com/pypa/pip.git /opt/pip +# Patch is specific to 23.3.1 +# Generated via: "git diff > pip-vcs-equivalency.patch" +ADD pip-vcs-equivalency.patch /opt/pip/ +RUN <> /opt/pip-tools.d/manifest.jax -echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/manifest.jax +echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in +echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in EOF ## Flax -ARG REPO_FLAX -ARG REF_FLAX -ARG SRC_PATH_FLAX -RUN get-source.sh -f ${REPO_FLAX} -r ${REF_FLAX} -d ${SRC_PATH_FLAX} -m /opt/pip-tools.d/manifest.flax +RUN get-source.sh -l flax -m ${MANIFEST_FILE} -o /opt/pip-tools.d/requirements-flax.in ## Transformer engine: check out source and build wheel -ARG REPO_TE -ARG REF_TE -ARG SRC_PATH_TE ENV NVTE_FRAMEWORK=jax ENV SRC_PATH_TE=${SRC_PATH_TE} -RUN <<"EOF" bash -ex -set -o pipefail +RUN <<"EOF" bash -ex -o pipefail pip install ninja && rm -rf ~/.cache/pip -get-source.sh -f ${REPO_TE} -r ${REF_TE} -d ${SRC_PATH_TE} +get-source.sh -l transformer-engine -m ${MANIFEST_FILE} pushd ${SRC_PATH_TE} python setup.py bdist_wheel && rm -rf build -echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/manifest.te +echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/requirements-te.in EOF # TODO: properly configure entrypoint @@ -119,3 +122,4 @@ EOF FROM mealkit as final RUN pip-finalize.sh + diff --git a/.github/container/Dockerfile.pax.amd64 b/.github/container/Dockerfile.pax.amd64 index a56dc4a17..9d0312c21 100644 --- a/.github/container/Dockerfile.pax.amd64 +++ b/.github/container/Dockerfile.pax.amd64 @@ -1,10 +1,6 @@ # syntax=docker/dockerfile:1-labs ARG BASE_IMAGE=ghcr.io/nvidia/jax:mealkit -ARG REPO_PAXML=https://github.com/google/paxml.git -ARG REPO_PRAXIS=https://github.com/google/praxis.git -ARG REF_PAXML=main -ARG REF_PRAXIS=main ARG SRC_PATH_PAXML=/opt/paxml ARG SRC_PATH_PRAXIS=/opt/praxis @@ -13,21 +9,17 @@ ARG SRC_PATH_PRAXIS=/opt/praxis ############################################################################### FROM ${BASE_IMAGE} as mealkit -ARG REPO_PAXML -ARG REPO_PRAXIS -ARG REF_PAXML -ARG REF_PRAXIS ARG SRC_PATH_PAXML ARG SRC_PATH_PRAXIS # update TE manifest file to install the [test] extras -RUN sed -i "s/transformer-engine @/transformer-engine[test] @/g" /opt/pip-tools.d/manifest.te +RUN sed -i "s/transformer-engine @/transformer-engine[test] @/g" /opt/pip-tools.d/requirements-te.in RUN <<"EOF" bash -ex -get-source.sh -f ${REPO_PAXML} -r ${REF_PAXML} -d ${SRC_PATH_PAXML} -get-source.sh -f ${REPO_PRAXIS} -r ${REF_PRAXIS} -d ${SRC_PATH_PRAXIS} -echo "-e file://${SRC_PATH_PAXML}[gpu]" >> /opt/pip-tools.d/manifest.pax -echo "-e file://${SRC_PATH_PRAXIS}" >> /opt/pip-tools.d/manifest.pax +get-source.sh -l paxml -m ${MANIFEST_FILE} +get-source.sh -l praxis -m ${MANIFEST_FILE} +echo "-e file://${SRC_PATH_PAXML}[gpu]" >> /opt/pip-tools.d/requirements-paxml.in +echo "-e file://${SRC_PATH_PRAXIS}" >> /opt/pip-tools.d/requirements-paxml.in for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do pushd ${src} diff --git a/.github/container/Dockerfile.pax.arm64 b/.github/container/Dockerfile.pax.arm64 index cdc714197..5a45c267b 100644 --- a/.github/container/Dockerfile.pax.arm64 +++ b/.github/container/Dockerfile.pax.arm64 @@ -1,12 +1,10 @@ # syntax=docker/dockerfile:1-labs ARG BASE_IMAGE=ghcr.io/nvidia/jax:mealkit -ARG REPO_PAXML=https://github.com/google/paxml.git -ARG REPO_PRAXIS=https://github.com/google/praxis.git -ARG REF_PAXML=main -ARG REF_PRAXIS=main ARG SRC_PATH_PAXML=/opt/paxml ARG SRC_PATH_PRAXIS=/opt/praxis +ARG SRC_PATH_TFTEXT=/opt/tensorflow-text +ARG SRC_PATH_LINGVO=/opt/lingvo ############################################################################### ## build tensorflow-text and lingvo, which do not have working arm64 pip wheels @@ -24,13 +22,11 @@ RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazeli #------------------------------------------------------------------------------ FROM wheel-builder as tftext-builder - -RUN <<"EOT" bash -exu -set -o pipefail +ARG SRC_PATH_TFTEXT +RUN <<"EOT" bash -exu -o pipefail pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.13.0 -git clone http://github.com/tensorflow/text.git /opt/tensorflow-text -cd /opt/tensorflow-text -git checkout v2.13.0 +get-source.sh -l tensorflow-text -m ${MANIFEST_FILE} +cd ${SRC_PATH_TFTEXT} ./oss_scripts/run_build.sh EOT @@ -39,17 +35,15 @@ EOT #------------------------------------------------------------------------------ FROM wheel-builder as lingvo-builder -ARG REPO_LINGVO=https://github.com/tensorflow/lingvo.git -ARG REF_LINGVO=master -ARG SRC_PATH_LINGVO=/opt/lingvo +ARG SRC_PATH_TFTEXT +ARG SRC_PATH_LINGVO -COPY --from=tftext-builder /opt/tensorflow-text/tensorflow_text*.whl /opt/ +COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ -RUN get-source.sh -f ${REPO_LINGVO} -r ${REF_LINGVO} -d ${SRC_PATH_LINGVO} +RUN get-source.sh -l lingvo -m ${MANIFEST_FILE} # build lingvo -RUN <<"EOT" bash -exu -set -o pipefail +RUN <<"EOT" bash -exu -o pipefail pushd ${SRC_PATH_LINGVO} git fetch origin pull/329/head:pr329 @@ -90,30 +84,27 @@ EOT ARG BASE_IMAGE FROM ${BASE_IMAGE} as mealkit -ARG REPO_PAXML -ARG REPO_PRAXIS -ARG REF_PAXML -ARG REF_PRAXIS ARG SRC_PATH_PAXML ARG SRC_PATH_PRAXIS +ARG SRC_PATH_TFTEXT COPY --from=lingvo-builder /tmp/lingvo/dist/lingvo*linux_aarch64.whl /opt/ -RUN echo "lingvo @ file://$(ls /opt/lingvo*.whl)" >> /opt/pip-tools.d/manifest.pax +RUN echo "lingvo @ file://$(ls /opt/lingvo*.whl)" >> /opt/pip-tools.d/requirements-paxml.in -COPY --from=tftext-builder /opt/tensorflow-text/tensorflow_text*.whl /opt/ -RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/manifest.pax +COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ +RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-paxml.in # paxml + praxis RUN <<"EOT" bash -ex -echo "tensorflow==2.13.0" >> /opt/pip-tools.d/manifest.pax -echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/manifest.pax -echo "chex==0.1.7" >> /opt/pip-tools.d/manifest.pax -echo "auditwheel" >> /opt/pip-tools.d/manifest.pax - -get-source.sh -f ${REPO_PAXML} -r ${REF_PAXML} -d ${SRC_PATH_PAXML} -get-source.sh -f ${REPO_PRAXIS} -r ${REF_PRAXIS} -d ${SRC_PATH_PRAXIS} -echo "-e file://${SRC_PATH_PAXML}[gpu]" >> /opt/pip-tools.d/manifest.pax -echo "-e file://${SRC_PATH_PRAXIS}" >> /opt/pip-tools.d/manifest.pax +echo "tensorflow==2.13.0" >> /opt/pip-tools.d/requirements-paxml.in +echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/requirements-paxml.in +echo "chex==0.1.7" >> /opt/pip-tools.d/requirements-paxml.in +echo "auditwheel" >> /opt/pip-tools.d/requirements-paxml.in + +get-source.sh -l paxml -m ${MANIFEST_FILE} +get-source.sh -l praxis -m ${MANIFEST_FILE} +echo "-e file://${SRC_PATH_PAXML}[gpu]" >> /opt/pip-tools.d/requirements-paxml.in +echo "-e file://${SRC_PATH_PRAXIS}" >> /opt/pip-tools.d/requirements-paxml.in for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do pushd ${src} diff --git a/.github/container/Dockerfile.t5x b/.github/container/Dockerfile.t5x index bcb1c790b..a5c43874d 100644 --- a/.github/container/Dockerfile.t5x +++ b/.github/container/Dockerfile.t5x @@ -1,8 +1,6 @@ # syntax=docker/dockerfile:1-labs ARG BASE_IMAGE=ghcr.io/nvidia/jax:mealkit -ARG REPO_T5X=https://github.com/google-research/t5x.git -ARG REF_T5X=main ARG SRC_PATH_T5X=/opt/t5x ############################################################################### @@ -11,13 +9,11 @@ ARG SRC_PATH_T5X=/opt/t5x FROM ${BASE_IMAGE} as mealkit -ARG REPO_T5X -ARG REF_T5X ARG SRC_PATH_T5X RUN <<"EOF" bash -ex -get-source.sh -f ${REPO_T5X} -r ${REF_T5X} -d ${SRC_PATH_T5X} -echo "-e file://${SRC_PATH_T5X}[gpu]" >> /opt/pip-tools.d/manifest.t5x +get-source.sh -l t5x -m ${MANIFEST_FILE} +echo "-e file://${SRC_PATH_T5X}[gpu]" >> /opt/pip-tools.d/requirements-t5x.in # remove head-of-tree specs from select dependencies pushd ${SRC_PATH_T5X} diff --git a/.github/container/bump.sh b/.github/container/bump.sh new file mode 100755 index 000000000..439f7fbb8 --- /dev/null +++ b/.github/container/bump.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +## Parse command-line arguments + +usage() { +cat < /dev/null && pwd ) + +MANIFEST_IN=${MANIFEST_IN:-} +MANIFEST_OUT=${MANIFEST_OUT:-} +ONLY_BUMP_PATCHES=${ONLY_BUMP_PATCHES:-0} + +if [[ -z "$MANIFEST_IN" ]]; then + echo "Need to provide a value for -i/--input-manifest" + usage 1 +fi + +if [[ -z "$MANIFEST_OUT" ]]; then + # Perform the update in place + MANIFEST_OUT=$MANIFEST_IN +else + # Write to a new file + cp $MANIFEST_IN $MANIFEST_OUT +fi + +for pkg in $(yq e 'keys | .[]' $MANIFEST_OUT); do + mode=$(yq e ".${pkg}.mode" $MANIFEST_OUT) + if [[ $mode == git-clone || $mode == pip-vcs ]] && [[ $ONLY_BUMP_PATCHES -eq 0 ]]; then + url=$(yq e ".${pkg}.url" $MANIFEST_OUT) + tracking_ref=$(yq e ".${pkg}.tracking_ref" $MANIFEST_OUT) + new_ref=$(git ls-remote $url $tracking_ref | awk '{print $1}') + yq e ".${pkg}.ref = \"$new_ref\"" -i $MANIFEST_OUT + fi + + has_patches=$(yq e ".${pkg} | has(\"patches\")" $MANIFEST_OUT) + if [[ $mode == git-clone && $has_patches == "true" ]]; then + url=$(yq e ".${pkg}.url" $MANIFEST_OUT) + repo_tmp=$(mktemp -d /tmp/${pkg}.XXXXXX) + git clone $url $repo_tmp + # Skip apply to defer to allow building upstream t5x and rosetta t5x + $SCRIPT_DIR/create-distribution.sh \ + --manifest $MANIFEST_OUT \ + --override_dir $repo_tmp \ + --package ${pkg} \ + --skip-apply + rm -rf $repo_tmp + fi +done diff --git a/rosetta/create-distribution.sh b/.github/container/create-distribution.sh similarity index 59% rename from rosetta/create-distribution.sh rename to .github/container/create-distribution.sh index a4db14ef3..3fbb8cef6 100755 --- a/rosetta/create-distribution.sh +++ b/.github/container/create-distribution.sh @@ -5,40 +5,52 @@ usage() { cat < with all entries -replaced with local patches. +-------------- -Relationship between --dir, --extra-dir, and --mirror-url repo args: - --dir: The upstream repo, locally cloned - --mirror-url: A mirror of the upstream repo - --extra-dir: A locally cloned mirror of the upstream repo. Helpful to incorporate changes from private repos. +This script has two modes of operation: + 1. $0 --skip-apply ... + 2. $0 ... -Patches in the --patchlist will be applied from the repos above according to the following rules: +Assuming you have: +t5x: + patches: + pull/3340/head: file://patches/t5x/pull-3340-head.patch + +(1) looks at the tracking-refs (pull/3340/head) of the patch and updates the local patch and the filename in the manifest (file://patches/t5x/pull-3340-head.patch) +(2) looks only at the filename value (file://patches/t5x/pull-3340-head.patch) and applies it + +-------------- + +The manifest can contain three versions of the repo: + url: The upstream repo, locally cloned + mirror_url: A miror of the upstream repo + extra_dir: Absolute path of locally cloned mirror of the upstream repo. Helpful to incorporate changes from private repos + +This script will in-place replace the patches in the --manifest with local patches. +Patches will be applied from the repos (if --skip-apply not set) above according to the following rules: Local patches (relative to this file): * ^file://.* - --dir: + url: * ^pull/.* - --mirror-url: + mirror_url: * ^mirror/.* * ^mirror/pull/.* - --extra-dir: + extra_dir: * Anything else EOF exit $1 } -args=$(getopt -o d:e:hm:p:r: --long dir:,extra-dir:,help,mirror-url:,patchlist:,ref: -- "$@") +args=$(getopt -o chm:p:s --long clean,help,manifest:,override_dir:,package:,skip-apply -- "$@") if [[ $? -ne 0 ]]; then echo usage 1 @@ -47,29 +59,29 @@ fi eval set -- "$args" while [ : ]; do case "$1" in - -d | --dir) - INSTALLED_DIR="$2" - shift 2 - ;; - -e | --extra-dir) - EXTRA_DIR="$2" - shift 2 + -c | --clean) + CLEAN_PATCHES=1 + shift 1 ;; -h | --help) usage ;; - -m | --mirror-url) - MIRROR_GIT_URL="$2" + -m | --manifest) + MANIFEST=$(readlink -f "$2") shift 2 ;; - -p | --patchlist) - PATCH_LIST=$(readlink -f "$2") + -o | --override_dir) + OVERRIDE_INSTALL_DIR="$2" shift 2 ;; - -r | --ref) - DISTRIBUTION_BASE_REF="$2" + -p | --package) + PACKAGE="$2" shift 2 ;; + -s | --skip-apply) + SKIP_APPLY=1 + shift 1 + ;; --) shift; break @@ -83,24 +95,35 @@ if [[ $# -ge 1 ]]; then usage 1 fi -set -euox pipefail -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +set -eou pipefail +# readlink -f $(pwd) is cross-platform way to ensure /tmp gets resolved correctly on macos +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && readlink -f $(pwd) ) -INSTALLED_DIR=${INSTALLED_DIR:-/opt/t5x} -DISTRIBUTION_BASE_REF=${DISTRIBUTION_BASE_REF:-HEAD} -MIRROR_GIT_URL=${MIRROR_GIT_URL:-https://github.com/nvjax-svc-0/t5x.git} -GEN_PATCH_DIR=${GEN_PATCH_DIR:-$SCRIPT_DIR/patches/$(basename $(git -C $INSTALLED_DIR remote get-url origin) .git)} -GEN_PATCH_LIST=$GEN_PATCH_DIR/$(basename $PATCH_LIST).gen -if [[ -e $GEN_PATCH_LIST ]]; then - echo "[WARNING]: $GEN_PATCH_LIST exists and will be overwritten" - rm -f $GEN_PATCH_LIST +if [[ -z "$MANIFEST" || -z "$PACKAGE" ]]; then + echo "--manifest and --package must be provided" + usage 1 fi -mkdir -p $GEN_PATCH_DIR -if [[ -z "${INSTALLED_DIR}" ]]; then - echo "[ERROR]: Need to specify -d/--dir" - usage 1 +BASE_DIR=${BASE_DIR:-/opt} +CLEAN_PATCHES=${CLEAN_PATCHES:-0} +UPSTREAM_URL=$(yq e ".${PACKAGE}.url" $MANIFEST) +# The tracking_ref is interpreted as the default "main" branch and all patches are +# assumed to be rooted on a sha on the tracking_ref's history +TRACKING_REF=$(yq e ".${PACKAGE}.tracking_ref" $MANIFEST) +INSTALLED_DIR=${OVERRIDE_INSTALL_DIR:-/opt/${PACKAGE}} +MIRROR_GIT_URL=$(yq e ".${PACKAGE}.mirror_url // \"\"" $MANIFEST) +EXTRA_DIR=$(yq e ".${PACKAGE}.extra_dir // \"\"" $MANIFEST) + +SKIP_APPLY=${SKIP_APPLY:-0} +GEN_PATCH_DIR=${GEN_PATCH_DIR:-$SCRIPT_DIR/patches/$PACKAGE} +# Associative arrays aren't available before bash <4.0, so maintaining separate key/value arrays +PATCH_KEYS=() +PATCH_VALUES=() +if [[ $CLEAN_PATCHES -eq 1 ]]; then + echo "--clean provided, so deleting $GEN_PATCH_DIR" + rm -rf $GEN_PATCH_DIR fi +mkdir -p $GEN_PATCH_DIR cd ${INSTALLED_DIR} @@ -121,19 +144,19 @@ done # Since PR may be rooted on a future commit on main (C), if we run "git merge-base origin/main D", we will get B # instead of C, which would cause us to cherry-pick future upstream commits. So, we will fetch the latest # upstream main to prevent this. -git fetch origin main +git fetch origin $TRACKING_REF -echo "[INFO]: Basing distribution on git-ref: ${DISTRIBUTION_BASE_REF} ($(git rev-parse ${DISTRIBUTION_BASE_REF}))" -# previous-HEAD's purpose is to point to the state of the repo before any changes are made whereas -# distribution-base is to point to the commit where we want to begin building the distribution on. -# Most of the time it will be the same, but it can be different. +# previous-HEAD's purpose is to point to the state of the repo before any distribution changes are made +# We do not rely on the manifest.yaml's .${library}.ref because local commits may be made on top by the upstream docker builds if ! git rev-parse --verify previous-HEAD >/dev/null 2>&1; then + echo "[INFO]: Basing distribution on HEAD ($(git rev-parse HEAD)) and marking that with the local branch: previous-HEAD" git branch --force previous-HEAD HEAD else + echo "[INFO]: Basing distribution on ref: previous-HEAD ($(git rev-parse previous-HEAD))" git switch previous-HEAD fi # Create a local branch to mark the base commit -git branch --force distribution-base ${DISTRIBUTION_BASE_REF} +git branch --force distribution-base previous-HEAD # Create a local branch for the distribution that starts from the base git branch --force rosetta-distribution distribution-base git switch rosetta-distribution @@ -172,9 +195,9 @@ if [[ -n "${EXTRA_DIR+x}" ]] && [[ -d ${EXTRA_DIR} ]]; then git remote add -f ${EXTRA_REMOTE_NAME} ${EXTRA_DIR} fi -################# -# Apply patches # -################# +#################### +# Helper Functions # +#################### fork-point() { main=$1 feat_branch=$2 @@ -189,14 +212,24 @@ fork-point() { merge_commit=$(git rev-list --ancestry-path ${feat_branch}..${main} | tail -n1) git merge-base ${merge_commit}^ ${feat_branch}^ } -# git-am + adds to generated patchlist -am+record() { +# Applies git-am and returns the local patch URI +apply-local-patch() { + # This is the associated array key used to update the patchlist + patch_name=$1 # Canonicalize path to remove extra slashes or dot syntax - patch_path=$(readlink -f $1) + patch_path=$(readlink -f $2) if [[ ! $patch_path =~ ^${SCRIPT_DIR} ]]; then echo "[ERROR]: patch_path=$patch_path should start with $SCRIPT_DIR" exit 1 fi + # Create a new generated patchlist (for reproducibility) + PATCH_KEYS+=($patch_name) + PATCH_VALUES+=("file://${patch_path#$SCRIPT_DIR/}") + if [[ "$SKIP_APPLY" -eq 1 ]]; then + echo "[INFO]: Skipping patch application: $patch_path" + return + fi + # Apply the patch git am --3way <$patch_path || ret_code=$? if [[ ${ret_code:-0} -ne 0 ]]; then @@ -210,12 +243,11 @@ $(git diff) EOF exit 1 fi - # Create a new generated patchlist (for reproducibility) - echo "file://${patch_path#$SCRIPT_DIR/}" >> $GEN_PATCH_LIST } -apply-patches() { - from=$1 - to=$2 +apply-ref-patches() { + patch_name=$1 + from=$2 + to=$3 # Normally we'd apply the changes with git cherry-pick, but we need to check if there are merge commits num_merge_commits=$(git rev-list --min-parents=2 --count $from..$to) if [[ $num_merge_commits -gt 0 ]]; then @@ -236,57 +268,69 @@ apply-patches() { git branch -D ${to_linear} fi # Apply the patch - am+record $GEN_PATCH_DIR/$patch_fname + apply-local-patch $patch_name $GEN_PATCH_DIR/$patch_fname } -MIRROR_REMOTE_NAME=mirror -if git remote show ${MIRROR_REMOTE_NAME} &>/dev/null; then - git remote remove ${MIRROR_REMOTE_NAME} +if [[ -n "${MIRROR_GIT_URL}" ]] ; then + MIRROR_REMOTE_NAME=mirror + if git remote show ${MIRROR_REMOTE_NAME} &>/dev/null; then + git remote remove ${MIRROR_REMOTE_NAME} + fi + git remote add -f ${MIRROR_REMOTE_NAME} ${MIRROR_GIT_URL} fi -git remote add -f ${MIRROR_REMOTE_NAME} ${MIRROR_GIT_URL} + +################# +# Apply patches # +################# IFS=$'\n' -for line in $(cat ${PATCH_LIST}); do - if [[ "${line}" =~ ^[[:blank:]]*$ ]] || [[ "${line}" =~ ^[[:blank:]]*\# ]]; then - continue - fi - git_ref=$(awk '{print $1}' <<< "${line}") - if [[ "${git_ref}" =~ ^file:// ]]; then - patch_path=$SCRIPT_DIR/${git_ref#file://} +for git_ref in $(yq e ".${PACKAGE}.patches | keys | .[]" $MANIFEST); do + if [[ $SKIP_APPLY -eq 0 ]]; then + # If we apply, then use the value, not the key + patch_uri=$(yq e ".${PACKAGE}.patches.${git_ref}" $MANIFEST) + patch_path=$SCRIPT_DIR/${patch_uri#file://} if [[ ! -f $patch_path ]]; then echo "[ERROR]: ${git_ref} refers to $patch_path which does not exist" exit 1 fi - am+record $patch_path + apply-local-patch $git_ref $patch_path continue elif [[ "${git_ref}" =~ ^pull/ ]]; then REMOTE_NAME=origin PR_ID=$(cut -d/ -f2 <<<"${git_ref}") branch=PR-${PR_ID} git fetch ${REMOTE_NAME} ${git_ref}:${branch} - main_branch=${REMOTE_NAME}/main + main_branch=${REMOTE_NAME}/${TRACKING_REF} elif [[ "${git_ref}" =~ ^mirror/pull/ ]]; then + if [[ -z "${MIRROR_GIT_URL}" ]] ; then + echo "[Error]: MIRROR_GIT_URL not provided so cannot apply patch=${git_ref}" + exit 1 + fi REMOTE_NAME=${MIRROR_REMOTE_NAME} PR_ID=$(cut -d/ -f3 <<<"${git_ref}") branch=PR-${PR_ID} git fetch ${REMOTE_NAME} $(cut -d/ -f2- <<<${git_ref}):${branch} - main_branch=${REMOTE_NAME}/main + main_branch=${REMOTE_NAME}/${TRACKING_REF} elif [[ "${git_ref}" =~ ^mirror/ ]]; then + if [[ -z "${MIRROR_GIT_URL}" ]] ; then + echo "[Error]: MIRROR_GIT_URL not provided so cannot apply patch=${git_ref}" + exit 1 + fi REMOTE_NAME=${MIRROR_REMOTE_NAME} # REMOTE_NAME not needed b/c git_ref already prefixed branch=${git_ref} - main_branch=${REMOTE_NAME}/main + main_branch=${REMOTE_NAME}/${TRACKING_REF} else if [[ -z "${EXTRA_DIR+x}" ]] || [[ ! -d ${EXTRA_DIR} ]]; then - echo "[WARNING]: EXTRA_DIR=${EXTRA_DIR} does not exist so cannot apply patch=${git_ref}" - continue + echo "[Error]: EXTRA_DIR=${EXTRA_DIR} does not exist so cannot apply patch=${git_ref}" + exit 1 fi REMOTE_NAME=${EXTRA_REMOTE_NAME} # Fetch both the feature branch and main so that we can cherry pick the entire branch branch=${REMOTE_NAME}/${git_ref}${TMP_BRANCH_SUFFIX} # Use main-tmp-rosetta instead of main b/c remote branches may have been updated and the local main is stale - main_branch=${REMOTE_NAME}/main${TMP_BRANCH_SUFFIX} + main_branch=${REMOTE_NAME}/${TRACKING_REF}${TMP_BRANCH_SUFFIX} fi fork_point=$(fork-point ${main_branch} ${branch}) - apply-patches ${fork_point} ${branch} || ret_code=$? + apply-ref-patches ${git_ref} ${fork_point} ${branch} || ret_code=$? if [[ ${ret_code:-0} -ne 0 ]]; then cat </dev/null; then git remote remove ${remote} diff --git a/.github/container/get-source.sh b/.github/container/get-source.sh index b8fff71a7..afa96d9b7 100755 --- a/.github/container/get-source.sh +++ b/.github/container/get-source.sh @@ -1,7 +1,7 @@ #!/bin/bash -## Fetch a Python package from a git repo and write the pip-tools input manifest to stdout +## Clone a git repo and write the pip-compile input to stdout ## Example: -## get-source.sh -f https://github.com/google/flax.git -r main -d /opt/flax +## get-source.sh -m manifest.yaml -l flax ## Output: ## -e /opt/flax @@ -9,47 +9,46 @@ usage() { echo "Usage: $0 [OPTION]..." - echo " -d, --dir PATH [Required] Local path to check out the source code." - echo " -f, --from URL [Required] URL of the source repo." + echo " -b, --base-dir DIR Directory to install package under. Default /opt" echo " -h, --help Print usage." - echo " -m, --manifest FILE Create a pip manifest file if specified" - echo " -r, --ref REF Git commit SHA, branch name, or tag name to checkout. Uses default branch if not specified." + echo " -l, --library LIB The library to clone, e.g., jax, flax, t5x" + echo " -m, --manifest FILE The JAX-Toolbox manifest yaml file" + echo " -o, --out-requirements Create a pip manifest file if specified" echo exit $1 } -args=$(getopt -o d:f:hm:r: --long dir:,from:,help,manifest:,ref: -- "$@") +args=$(getopt -o b:hl:m:o: --long base-dir:,help,library:,manifest:,out-requirements: -- "$@") if [[ $? -ne 0 ]]; then exit 1 fi ## Set default arguments -GIT_REPO="" -GIT_REF="${GIT_REF:-HEAD}" -INSTALL_DIR="" -MANIFEST_FILE="" +BASE_INSTALL_DIR="/opt" +MANIFEST="" +OUT_REQUIREMENTS_FILE="" eval set -- "$args" while [ : ]; do case "$1" in - -d | --dir) - INSTALL_DIR="$2" - shift 2 - ;; - -f | --from) - GIT_REPO="$2" + -b | --base-dir) + BASE_INSTALL_DIR=$(readlink -f "$2") shift 2 ;; -h | --help) usage ;; + -l | --library) + LIBRARY="$2" + shift 2 + ;; -m | --manifest) - MANIFEST_FILE="$2" + MANIFEST="$2" shift 2 ;; - -r | --ref) - GIT_REF="$2" + -o | --out-requirements) + OUT_REQUIREMENTS_FILE="$2" shift 2 ;; --) @@ -60,21 +59,30 @@ while [ : ]; do done if [[ $# -ge 1 ]]; then - echo "Un-recognized argument: $*" && echo + echo "Un-recognized argument: $*" usage 1 fi -if [[ ! -n "${GIT_REPO}" ]]; then - echo "Source repository not speicified." && echo +if [[ -z "${LIBRARY}" ]]; then + echo "Library not specified." usage 1 fi -if [[ ! -n "${INSTALL_DIR}" ]]; then - echo "Check out destination not specified." && echo +if [[ -z "${MANIFEST}" ]]; then + echo "Manifest not specified." usage 1 fi ## check out the source +PACKAGE_MODE=$(yq e ".${LIBRARY}.mode" $MANIFEST) +if [[ "${PACKAGE_MODE}" != "git-clone" ]]; then + echo "--library=${LIBRARY} mode is ${PACKAGE_MODE} which is not meant to be cloned. Update mode to \"git-clone\" if this repo should be cloned" + exit 1 +fi + +GIT_REPO=$(yq e ".${LIBRARY}.url" $MANIFEST) +GIT_REF=$(yq e ".${LIBRARY}.ref" $MANIFEST) +INSTALL_DIR=${BASE_INSTALL_DIR}/$LIBRARY echo "Fetching $GIT_REPO#$GIT_REF to $INSTALL_DIR" @@ -87,5 +95,5 @@ git submodule init git submodule update --recursive popd -echo "Writing to ${MANIFEST_FILE}:" -echo "-e file://${INSTALL_DIR}" | tee -a ${MANIFEST_FILE} +echo "Writing to ${OUT_REQUIREMENTS_FILE}:" +echo "-e file://${INSTALL_DIR}" | tee -a ${OUT_REQUIREMENTS_FILE} diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml new file mode 100644 index 000000000..6077254ad --- /dev/null +++ b/.github/container/manifest.yaml @@ -0,0 +1,102 @@ +# Updated in: XXX +jax: + url: https://github.com/google/jax.git + tracking_ref: main + ref: b032a0271e3e2ea8d0df64d2f3f1a1e450a38dc9 # 2023-11-15 + mode: git-clone +xla: + url: https://github.com/openxla/xla.git + tracking_ref: main + ref: 8fb606ffa03c030035d6c0c9d05534dbf6701906 # 2023-11-15 + mode: git-clone +flax: + url: https://github.com/google/flax.git + mirror_url: https://github.com/nvjax-svc-0/flax.git + tracking_ref: main + ref: a572f6af2fef565c0f9ba2fc12b781e9e3385140 + mode: git-clone + patches: + pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules +transformer-engine: + url: https://github.com/NVIDIA/TransformerEngine.git + tracking_ref: main + ref: d76118d90df0422d52261adc26a5f4351a1dd71f + mode: git-clone +t5x: + url: https://github.com/google-research/t5x.git + mirror_url: https://github.com/nvjax-svc-0/t5x.git + tracking_ref: main + ref: c39a33a35bb2f03f6d36455e6378620c6634a995 + mode: git-clone + patches: + mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore + mirror/patch/dali-support: file://patches/t5x/mirror-patch-dali-support.patch # pull/1393/head # https://github.com/google-research/t5x/pull/1393: Adds DALI support to t5x + mirror/patch/t5x_te_in_contrib_noindent: file://patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100) +paxml: + url: https://github.com/google/paxml.git + mirror_url: https://github.com/nvjax-svc-0/paxml.git + tracking_ref: main + ref: 6c811d5e8f82a8aa75530b50223302d98f47e984 + mode: git-clone + patches: + pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support +praxis: + url: https://github.com/google/praxis.git + mirror_url: https://github.com/nvjax-svc-0/praxis.git + tracking_ref: main + ref: fcadc09773e32a18abd5b0240e07da33316a9636 + mode: git-clone + patches: + pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. +lingvo: + # Used only in ARM pax builds + url: https://github.com/tensorflow/lingvo.git + tracking_ref: master + ref: 36a1e314864533eeb1cc1e6590e86c10c03b1516 + mode: git-clone +tensorflow-text: + # Used only in ARM pax builds + url: https://github.com/tensorflow/text.git + tracking_ref: v2.13.0 + ref: 917a681d7220ebf9b62a08b6f9ce7b7db886ddef + mode: git-clone +pydantic: + version: X.Y.Z + mode: pip-constraint +# Used by praxis +fiddle: + url: https://github.com/google/fiddle.git + tracking_ref: main + ref: b13db6481720bc897f4efd7e04c7ba4f5907ce74 + mode: pip-vcs +# Used by t5x +airio: + url: https://github.com/google/airio.git + tracking_ref: main + ref: 69b3ec4ded478ad9cacdc97652a9d086a6a644c4 + mode: pip-vcs +clu: + url: https://github.com/google/CommonLoopUtils.git + tracking_ref: main + ref: 7ba2a9d83a3bc1a97b59482c2f02dc4b3614bc31 + mode: pip-vcs +dllogger: + url: https://github.com/NVIDIA/dllogger.git + tracking_ref: master + ref: 0540a43971f4a8a16693a9de9de73c1072020769 + mode: pip-vcs +jestimator: + url: https://github.com/google-research/jestimator.git + tracking_ref: master + ref: fa143d93e337ca8ab77c4510baf21ae52af24ab2 + mode: pip-vcs +optax: + url: https://github.com/deepmind/optax.git + tracking_ref: master + ref: bf987e15eacf6efeb1a1a51b8868c094c3a15f9b + mode: pip-vcs +seqio: + url: https://github.com/google/seqio.git + tracking_ref: main + ref: 515d917bf58da4103a2bbf39c3716213c36aff03 + mode: pip-vcs diff --git a/.github/container/patches/flax/PR-3340.patch b/.github/container/patches/flax/PR-3340.patch new file mode 100644 index 000000000..75a06d680 --- /dev/null +++ b/.github/container/patches/flax/PR-3340.patch @@ -0,0 +1,440 @@ +From 6b8bcac6234f156a763f7e535670cb094509c350 Mon Sep 17 00:00:00 2001 +From: ashors1 +Date: Fri, 2 Jun 2023 15:01:21 -0700 +Subject: [PATCH 1/2] add t5x sharding annotations to flax layers + +--- + flax/linen/attention.py | 33 +++++++++++++++++++++++------ + flax/linen/linear.py | 41 ++++++++++++++++++++++++++++--------- + flax/linen/normalization.py | 25 ++++++++++++++++++---- + 3 files changed, 79 insertions(+), 20 deletions(-) + +diff --git a/flax/linen/attention.py b/flax/linen/attention.py +index f5a388a2..20537921 100644 +--- a/flax/linen/attention.py ++++ b/flax/linen/attention.py +@@ -32,6 +32,7 @@ from flax.linen.linear import ( + ) + from flax.linen.module import Module, compact, merge_param + from flax.linen.normalization import LayerNorm ++from flax.linen.partitioning import variable_with_axes + + PRNGKey = jax.Array + Shape = Tuple[int, ...] +@@ -223,6 +224,17 @@ class MultiHeadDotProductAttention(Module): + num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. + normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442). ++ in_proj_kernel_axes: a tuple of axes over which to shard the kernel for ++ the attention in-projection. ++ in_proj_bias_axes: a tuple of axis names associated with the bias for ++ the attention in-projection. ++ out_proj_kernel_axes: a tuple of axis names associated with the kernel for ++ the attention out-projection. ++ out_proj_bias_axes: a tuple of axis names associated with the bias for ++ the attention out-projection. ++ decode_axes: a tuple of axis names associated with auroregressive cache. ++ Only used when decode=True. ++ + """ + + num_heads: int +@@ -247,6 +259,11 @@ class MultiHeadDotProductAttention(Module): + out_dot_general: Optional[DotGeneralT] = None + qkv_dot_general_cls: Any = None + out_dot_general_cls: Any = None ++ in_proj_kernel_axes: Tuple[str, ...] = None ++ in_proj_bias_axes: Tuple[str, ...] = None ++ out_proj_kernel_axes: Tuple[str, ...] = None ++ out_proj_bias_axes: Tuple[str, ...] = None ++ decode_axes: Tuple[str, ...] = None + + @overload + def __call__( +@@ -378,6 +395,8 @@ class MultiHeadDotProductAttention(Module): + precision=self.precision, + dot_general=self.qkv_dot_general, + dot_general_cls=self.qkv_dot_general_cls, ++ kernel_axes=self.in_proj_kernel_axes, ++ bias_axes=self.in_proj_bias_axes, + ) + # project inputs_q to multi-headed q/k/v + # dimensions are then [batch..., length, n_heads, n_features_per_head] +@@ -398,14 +417,14 @@ class MultiHeadDotProductAttention(Module): + if self.decode: + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable('cache', 'cached_key') +- cached_key = self.variable( +- 'cache', 'cached_key', jnp.zeros, key.shape, key.dtype ++ cached_key = variable_with_axes('cache', 'cached_key', ++ jnp.zeros, key.shape, key.dtype, axes=self.decode_axes + ) +- cached_value = self.variable( +- 'cache', 'cached_value', jnp.zeros, value.shape, value.dtype ++ cached_value = variable_with_axes('cache', 'cached_value', ++ jnp.zeros, value.shape, value.dtype, axes=self.decode_axes + ) +- cache_index = self.variable( +- 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) ++ cache_index = variable_with_axes('cache', 'cache_index', ++ lambda: jnp.array(0, dtype=jnp.int32), axes=None + ) + if is_initialized: + ( +@@ -483,6 +502,8 @@ class MultiHeadDotProductAttention(Module): + dot_general=self.out_dot_general, + dot_general_cls=self.out_dot_general_cls, + name='out', # type: ignore[call-arg] ++ kernel_axes=self.out_proj_kernel_axes, ++ bias_axes=self.out_proj_bias_axes, + )(x) + return out + +diff --git a/flax/linen/linear.py b/flax/linen/linear.py +index f4afd805..999acf2c 100644 +--- a/flax/linen/linear.py ++++ b/flax/linen/linear.py +@@ -36,6 +36,7 @@ from flax.core import meta + from flax.linen import initializers + from flax.linen.dtypes import promote_dtype + from flax.linen.module import Module, compact ++from flax.linen.partitioning import param_with_axes + + PRNGKey = Any + Shape = Tuple[int, ...] +@@ -81,6 +82,8 @@ class DenseGeneral(Module): + bias_init: initializer function for the bias. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. ++ kernel_axes: a tuple of axes associated with the kernel. ++ bias_axes: a tuple of axes associated with the bias. + """ + + features: Union[int, Sequence[int]] +@@ -97,6 +100,8 @@ class DenseGeneral(Module): + # Deprecated. Will be removed. + dot_general: Optional[DotGeneralT] = None + dot_general_cls: Any = None ++ kernel_axes: Tuple[str, ...] = None ++ bias_axes: Tuple[str, ...] = None + + @compact + def __call__(self, inputs: Array) -> Array: +@@ -145,8 +150,9 @@ class DenseGeneral(Module): + if ax not in axis + ) + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features +- kernel = self.param( +- 'kernel', kernel_init_wrap, batch_shape + kernel_shape, self.param_dtype ++ kernel = param_with_axes( ++ 'kernel', kernel_init_wrap, batch_shape + kernel_shape, ++ self.param_dtype, axes=self.kernel_axes + ) + + batch_ind = tuple(range(n_batch_dims)) +@@ -164,9 +170,11 @@ class DenseGeneral(Module): + return meta.replace_boxed(bias, jnp.reshape(bias.unbox(), shape)) + return jnp.reshape(bias, shape) + +- bias = self.param( +- 'bias', bias_init_wrap, batch_shape + features, self.param_dtype ++ bias = param_with_axes( ++ 'bias', bias_init_wrap, batch_shape + features, ++ self.param_dtype, axes=self.bias_axes + ) ++ + else: + bias = None + +@@ -204,6 +212,8 @@ class Dense(Module): + for details. + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. ++ kernel_axes: a tuple of axes associated with the kernel. ++ bias_axes: a tuple of axes associated with the bias. + """ + + features: int +@@ -218,6 +228,8 @@ class Dense(Module): + # Deprecated. Will be removed. + dot_general: Optional[DotGeneralT] = None + dot_general_cls: Any = None ++ kernel_axes: Tuple[str, ...] = None ++ bias_axes: Tuple[str, ...] = None + + @compact + def __call__(self, inputs: Array) -> Array: +@@ -229,15 +241,18 @@ class Dense(Module): + Returns: + The transformed input. + """ +- kernel = self.param( ++ kernel = param_with_axes( + 'kernel', + self.kernel_init, + (jnp.shape(inputs)[-1], self.features), + self.param_dtype, ++ axes=self.kernel_axes + ) + if self.use_bias: +- bias = self.param( +- 'bias', self.bias_init, (self.features,), self.param_dtype ++ bias = param_with_axes( ++ 'bias', self.bias_init, (self.features,), ++ self.param_dtype, ++ axes=self.bias_axes + ) + else: + bias = None +@@ -331,6 +346,8 @@ class _Conv(Module): + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. ++ kernel_axes: a tuple of axes associated with the kernel. ++ bias_axes: a tuple of axes associated with the bias. + """ + + features: int +@@ -352,6 +369,8 @@ class _Conv(Module): + # Deprecated. Will be removed. + conv_general_dilated: Optional[ConvGeneralDilatedT] = None + conv_general_dilated_cls: Any = None ++ kernel_axes: Tuple[str, ...] = None ++ bias_axes: Tuple[str, ...] = None + + @property + def shared_weights(self) -> bool: # type: ignore +@@ -496,8 +515,10 @@ class _Conv(Module): + f'Shapes are: {self.mask.shape}, {kernel_shape}' + ) + +- kernel = self.param( +- 'kernel', self.kernel_init, kernel_shape, self.param_dtype ++ kernel = param_with_axes( ++ 'kernel', self.kernel_init, kernel_shape, ++ self.param_dtype, ++ axes=self.kernel_axes + ) + + if self.mask is not None: +@@ -511,7 +532,7 @@ class _Conv(Module): + # One bias weight per output entry, unshared betwen pixels. + bias_shape = conv_output_shape[1:] + +- bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) ++ bias = param_with_axes('bias', self.bias_init, bias_shape, self.param_dtype, axes=self.bias_axes) + else: + bias = None + +diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py +index 076fd680..6eff2dd1 100644 +--- a/flax/linen/normalization.py ++++ b/flax/linen/normalization.py +@@ -24,6 +24,7 @@ from jax import lax + from jax.nn import initializers + + from flax.linen import dtypes, module, transforms ++from flax.linen.partitioning import param_with_axes + + PRNGKey = Any + Array = Any +@@ -152,6 +153,7 @@ def _normalize( + use_scale: bool, + bias_init: Callable[[PRNGKey, Shape, Dtype], Array], + scale_init: Callable[[PRNGKey, Shape, Dtype], Array], ++ axes: Tuple[str, ...] = None, + ): + """Normalizes the input of a normalization layer and optionally applies a learned scale and bias. + +@@ -171,6 +173,7 @@ def _normalize( + use_scale: If true, scale the output. + bias_init: Initialization function for the bias term. + scale_init: Initialization function for the scaling function. ++ axes: A tuple of axis names over which to shard parameters. + + Returns: + The normalized input. +@@ -189,15 +192,17 @@ def _normalize( + mul = lax.rsqrt(var + epsilon) + args = [x] + if use_scale: +- scale = mdl.param( +- 'scale', scale_init, reduced_feature_shape, param_dtype ++ scale = param_with_axes( ++ 'scale', scale_init, reduced_feature_shape, ++ param_dtype, axes=axes, module=mdl + ).reshape(feature_shape) + mul *= scale + args.append(scale) + y *= mul + if use_bias: +- bias = mdl.param( +- 'bias', bias_init, reduced_feature_shape, param_dtype ++ bias = param_with_axes( ++ 'bias', bias_init, reduced_feature_shape, ++ param_dtype, axes=axes, module=mdl + ).reshape(feature_shape) + y += bias + args.append(bias) +@@ -280,6 +285,7 @@ class BatchNorm(Module): + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. ++ pjit_axis_names: A tuple of axis names. + """ + + use_running_average: Optional[bool] = None +@@ -295,6 +301,7 @@ class BatchNorm(Module): + axis_name: Optional[str] = None + axis_index_groups: Any = None + use_fast_variance: bool = True ++ pjit_axis_name: Tuple[str, ...] = None + + @compact + def __call__(self, x, use_running_average: Optional[bool] = None, mask=None): +@@ -368,6 +375,7 @@ class BatchNorm(Module): + self.use_scale, + self.bias_init, + self.scale_init, ++ self.pjit_axis_name, + ) + + +@@ -405,6 +413,7 @@ class LayerNorm(Module): + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. ++ pjit_axis_names: A tuple of axis names. + """ + + epsilon: float = 1e-6 +@@ -419,6 +428,7 @@ class LayerNorm(Module): + axis_name: Optional[str] = None + axis_index_groups: Any = None + use_fast_variance: bool = True ++ pjit_axis_name: Tuple[str, ...] = None + + @compact + def __call__(self, x): +@@ -453,6 +463,7 @@ class LayerNorm(Module): + self.use_scale, + self.bias_init, + self.scale_init, ++ self.pjit_axis_name, + ) + + +@@ -497,6 +508,7 @@ class RMSNorm(Module): + example, `[[0, 1], [2, 3]]` would independently batch-normalize over the + examples on the first two and last two devices. See `jax.lax.psum` for + more details. ++ pjit_axis_names: A tuple of axis names. + """ + + epsilon: float = 1e-6 +@@ -508,6 +520,7 @@ class RMSNorm(Module): + feature_axes: Axes = -1 + axis_name: Optional[str] = None + axis_index_groups: Any = None ++ pjit_axis_name: Tuple[str, ...] = None + + @compact + def __call__(self, x): +@@ -542,6 +555,7 @@ class RMSNorm(Module): + self.use_scale, + initializers.zeros, + self.scale_init, ++ self.pjit_axis_name, + ) + + +@@ -582,6 +596,7 @@ class GroupNorm(Module): + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. ++ pjit_axis_names: A tuple of axis names. + """ + + num_groups: Optional[int] = 32 +@@ -596,6 +611,7 @@ class GroupNorm(Module): + axis_name: Optional[str] = None + axis_index_groups: Any = None + use_fast_variance: bool = True ++ pjit_axis_name: Tuple[str, ...] = None + + @compact + def __call__(self, x): +@@ -668,6 +684,7 @@ class GroupNorm(Module): + self.use_scale, + self.bias_init, + self.scale_init, ++ self.pjit_axis_name, + ) + + +-- +2.25.1 + + +From d1f3ec337b85b5c5377aab72d814adfc89dd4af5 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Mon, 2 Oct 2023 16:10:05 -0700 +Subject: [PATCH 2/2] Added ConvTranspose sharding annotations (#3) + +Co-authored-by: sahilj +--- + flax/linen/linear.py | 24 ++++++++++++++++++++---- + 1 file changed, 20 insertions(+), 4 deletions(-) + +diff --git a/flax/linen/linear.py b/flax/linen/linear.py +index 999acf2c..8e031c77 100644 +--- a/flax/linen/linear.py ++++ b/flax/linen/linear.py +@@ -708,6 +708,21 @@ class ConvTranspose(Module): + ] = initializers.zeros_init() + transpose_kernel: bool = False + ++ def param_with_axes( ++ self, ++ name: str, ++ init_fn, ++ *init_args, ++ axes: Optional[Tuple[str, ...]] = None, ++ module: Optional[Module] = None): ++ return param_with_axes( ++ name, ++ init_fn, ++ *init_args, ++ axes=axes, ++ module=module, ++ ) ++ + @compact + def __call__(self, inputs: Array) -> Array: + """Applies a transposed convolution to the inputs. +@@ -764,8 +779,9 @@ class ConvTranspose(Module): + f'Shapes are: {self.mask.shape}, {kernel_shape}' + ) + +- kernel = self.param( +- 'kernel', self.kernel_init, kernel_shape, self.param_dtype ++ kernel = self.param_with_axes( ++ 'kernel', self.kernel_init, kernel_shape, self.param_dtype, ++ axes=('height', 'width', 'input', 'embed') + ) + + if self.mask is not None: +@@ -776,8 +792,8 @@ class ConvTranspose(Module): + padding_lax = 'VALID' + + if self.use_bias: +- bias = self.param( +- 'bias', self.bias_init, (self.features,), self.param_dtype ++ bias = self.param_with_axes( ++ 'bias', self.bias_init, (self.features,), self.param_dtype, axes=('embed', ) + ) + else: + bias = None +-- +2.25.1 + diff --git a/.github/container/patches/paxml/PR-46.patch b/.github/container/patches/paxml/PR-46.patch new file mode 100644 index 000000000..0f9879e4a --- /dev/null +++ b/.github/container/patches/paxml/PR-46.patch @@ -0,0 +1,1567 @@ +From 9cc67accfa9d2c3e4dd09fe58955730841aac309 Mon Sep 17 00:00:00 2001 +From: ashors1 +Date: Tue, 18 Jul 2023 10:27:03 -0700 +Subject: [PATCH 1/9] add TE support + +--- + paxml/contrib/gpu/scripts_gpu/configs.py | 26 +- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 324 +++++++++++++++++++++ + paxml/main.py | 67 +++-- + paxml/tasks_lib.py | 3 +- + paxml/trainer_lib.py | 36 ++- + 5 files changed, 401 insertions(+), 55 deletions(-) + create mode 100644 paxml/contrib/gpu/scripts_gpu/te_helper.py + +diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py +index 71542ad..50b848e 100644 +--- a/paxml/contrib/gpu/scripts_gpu/configs.py ++++ b/paxml/contrib/gpu/scripts_gpu/configs.py +@@ -21,6 +21,7 @@ from paxml import experiment_registry + from paxml import tasks_lib + from paxml.contrib.gpu.scripts_gpu.tasks import LambadaDataset + from paxml.contrib.gpu.scripts_gpu.tasks import PileUnsupervisedDataset ++from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper + from paxml.tasks.lm.params.c4 import TransformerLmSpmdAdam + from praxis import base_layer + from praxis import layers +@@ -107,7 +108,7 @@ class GPT126M(TransformerLmSpmdAdam): + + MAX_SEQ_LEN = 2048 + VOCAB_SIZE = 50304 +- PACKED_INPUT = True ++ PACKED_INPUT = False + PERCORE_BATCH_SIZE = 4 + + NUM_LAYERS = 12 +@@ -159,10 +160,21 @@ class GPT126M(TransformerLmSpmdAdam): + fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated + ): + stacked_p = stacked_p.block +- transformer_layer_p = stacked_p.transformer_layer_params_tpl +- transformer_layer_p.ln_tpl.reductions_in_fp32 = True +- transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True ++ + task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True ++ if not TransformerEngineHelper.is_enabled_te(): ++ transformer_layer_p = stacked_p.transformer_layer_params_tpl ++ transformer_layer_p.ln_tpl.reductions_in_fp32 = True ++ transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True ++ else: ++ stacked_p = TransformerEngineHelper.get_stack_transformer( ++ stacked_p, jnp.dtype(self.FPROP_DTYPE)) ++ if issubclass(fdl.get_callable(model_p.lm_tpl.stacked_transformer_tpl), ++ transformers.StackedTransformerRepeated): ++ model_p.lm_tpl.stacked_transformer_tpl.block = stacked_p ++ else: ++ model_p.lm_tpl.stacked_transformer_tpl = stacked_p ++ + + model_p.params_init = WeightInit.Gaussian(self.INIT_STD) + softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) +@@ -172,7 +184,6 @@ class GPT126M(TransformerLmSpmdAdam): + + return task_p + +- + @experiment_registry.register + class Pile126M(GPT126M, PileUnsupervisedDataset): + +@@ -180,11 +191,11 @@ class Pile126M(GPT126M, PileUnsupervisedDataset): + task_p = super().task() + return task_p + +- + @experiment_registry.register + class Lambada126M(GPT126M, LambadaDataset): + + ICI_MESH_SHAPE = [8,1,1] ++ PACKED_INPUT=False + + def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]: + task_p = super().task() +@@ -239,7 +250,6 @@ class GPT5B(Pile126M): + return task_p + + +-## 96 node + @experiment_registry.register + class GPT175B(Pile126M): + +@@ -249,7 +259,7 @@ class GPT175B(Pile126M): + # Known as MLP_DIM in t5x + HIDDEN_DIMS = MODEL_DIMS * 4 + # Defaults to MODEL_DIMS // NUM_HEADS. +- DIMS_PER_HEAD = None ++ DIMS_PER_HEAD = 128 + # Known as NUM_EMBEDDINGS in t5x + VOCAB_SIZE = 50257 + USE_REPEATED_LAYER = True +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +new file mode 100644 +index 0000000..d44ca67 +--- /dev/null ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -0,0 +1,324 @@ ++import os ++from contextlib import contextmanager ++from typing import Optional, Sequence ++ ++import jax ++import jax.numpy as jnp ++from jax.ad_checkpoint import checkpoint_name ++from praxis import base_layer ++from praxis import pax_fiddle ++from praxis import pytypes ++from praxis.layers import transformers ++from praxis.layers import stochastics ++ ++try: ++ import transformer_engine.jax as te ++ import transformer_engine.jax.flax as te_flax ++ import transformer_engine.jax.praxis as te_praxis ++ from transformer_engine.common import recipe ++ _IS_TRANSFORMER_ENGINE_INSTALLED = True ++ DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME] ++ import praxis.layers.repeats as praxis_repeat ++ # This is to make Repeat module correctly generate collections we need. ++ praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes ++ te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) ++ ++except ModuleNotFoundError as e: ++ _IS_TRANSFORMER_ENGINE_INSTALLED = False ++ DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST ++ ++ ++LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] ++JTensor = pytypes.JTensor ++ ++class StackedTransformer(transformers.StackedTransformer): ++ """A mirror of StackedTransformer layers in Praxis.""" ++ ++ def setup(self) -> None: ++ ++ assert self.num_layers > 0 ++ assert self.model_dims > 0 ++ assert self.hidden_dims > 0 ++ assert self.num_heads > 0 ++ assert 0.0 <= self.dropout_prob < 1.0 ++ assert 0.0 <= self.input_dropout_prob < 1.0 ++ ++ def _layer_params(i): ++ """Construct i-th layer params.""" ++ if isinstance(self.transformer_layer_params_tpl, Sequence): ++ factor = self.num_layers // len(self.transformer_layer_params_tpl) ++ ii = i // factor ++ p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii]) ++ else: ++ p_i = self._clone_layer_params(self.transformer_layer_params_tpl) ++ p_i.name = f'layer_{i}' ++ ++ p_i.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) ++ p_i.layer_type = te_praxis.TransformerLayerType.DECODER if self.use_cross_attention \ ++ else te_praxis.TransformerLayerType.ENCODER ++ p_i.num_attention_heads = self.num_heads ++ p_i.hidden_size = self.model_dims ++ p_i.mlp_hidden_size = self.hidden_dims ++ assert self.dim_per_head == self.model_dims // self.num_heads ++ assert self.packed_input == False ++ assert len(self.moe_layers) == 0 ++ assert self.ngrammer_tpls is None ++ ++ if self.ngrammer_tpls is not None: ++ if self.ngrammer_tpls[i] is not None: ++ p_i.ngrammer_tpl = self.ngrammer_tpls[i] ++ return p_i ++ ++ if isinstance(self.transformer_layer_params_tpl, (list, tuple)): ++ if self.num_layers % len(self.transformer_layer_params_tpl): ++ raise ValueError('num_layers should be divisible by ' ++ 'transformer_layer_params_tpl') ++ ++ layer_params = [_layer_params(i) for i in range(self.num_layers)] ++ self.create_children('x_layers', layer_params) ++ ++ if self.input_dropout_prob > 0.0: ++ self.create_child( ++ 'input_dropout', ++ pax_fiddle.Config( ++ stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob ++ ), ++ ) ++ ++ def __call__(self, ++ inputs: JTensor, ++ paddings: JTensor, ++ segment_mask: Optional[JTensor] = None, ++ cross_inputs: Optional[JTensor] = None, ++ cross_paddings: Optional[JTensor] = None, ++ cross_segment_mask: Optional[JTensor] = None, ++ segment_pos: Optional[JTensor] = None) -> JTensor: ++ ++ if self.packed_input: ++ assert segment_mask is not None ++ ++ if self.use_cross_attention: ++ assert cross_inputs is not None ++ assert cross_paddings is not None ++ if self.packed_input: ++ assert cross_segment_mask is not None ++ ++ attention_mask, cross_attention_mask = transformers.compute_attention_masks_for_fprop( ++ inputs, ++ paddings, ++ self.mask_self_attention, ++ segment_mask, ++ cross_inputs, ++ cross_paddings, ++ cross_segment_mask, ++ fold_padding_with_segment_mask=self.fold_padding_with_segment_mask, ++ ) ++ ++ x_out = inputs ++ if self.input_dropout_prob > 0.0: ++ x_out = self.input_dropout(x_out) ++ ++ attention_mask = 1 - (attention_mask == 0) ++ attention_mask = attention_mask.astype(jnp.uint8) ++ ++ if cross_attention_mask is not None: ++ cross_attention_mask = 1 - (cross_attention_mask == 0) ++ cross_attention_mask = cross_attention_mask.astype(jnp.uint8) ++ ++ for i in range(self.num_layers): ++ x_in = x_out ++ x_out = self.x_layers[i]( ++ inputs=x_in, ++ attention_mask=attention_mask, ++ encoded=cross_inputs, ++ encoder_decoder_mask=cross_attention_mask) ++ x_out = checkpoint_name(x_out, 'transformer_layer_out') ++ return x_out ++ ++ ++class TransformerEngineHelperBase: ++ ++ @staticmethod ++ def get_stack_transformer(stacked_transformer_p, dtype): ++ raise NotImplementedError ++ ++ @staticmethod ++ def update_fp8_metas_if_needed(mdl_vars, grads): ++ raise NotImplementedError ++ ++ @staticmethod ++ def include_fp8_for_grads_if_needed(variables): ++ raise NotImplementedError ++ ++ @staticmethod ++ def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): ++ raise NotImplementedError ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): ++ raise NotImplementedError ++ ++ ++class TENotInstalledHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def get_stack_transformer(stacked_transformer_p, dtype): ++ return stacked_transformer_p ++ ++ @staticmethod ++ def update_fp8_metas_if_needed(mdl_vars, grads): ++ return mdl_vars ++ ++ @staticmethod ++ def include_fp8_for_grads_if_needed(variables): ++ return variables ++ ++ @staticmethod ++ def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): ++ return grads ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): ++ try: ++ yield ++ finally: ++ pass ++ ++ ++class TEInstalledHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def get_stack_transformer(stacked_transformer_p, dtype): ++ ++ assert stacked_transformer_p.cls == transformers.StackedTransformer ++ ++ te_stacked_transformer_p = pax_fiddle.Config(StackedTransformer, ++ use_cross_attention=stacked_transformer_p.use_cross_attention, ++ mask_self_attention=stacked_transformer_p.mask_self_attention, ++ num_layers=stacked_transformer_p.num_layers, ++ model_dims=stacked_transformer_p.model_dims, ++ hidden_dims=stacked_transformer_p.hidden_dims, ++ num_heads=stacked_transformer_p.num_heads, ++ dim_per_head=stacked_transformer_p.dim_per_head, ++ dropout_prob=stacked_transformer_p.dropout_prob, ++ atten_dropout_prob=stacked_transformer_p.atten_dropout_prob, ++ residual_dropout_prob=stacked_transformer_p.residual_dropout_prob, ++ relu_dropout_prob=stacked_transformer_p.relu_dropout_prob, ++ residual_droppath_prob=stacked_transformer_p.residual_droppath_prob, ++ input_dropout_prob=stacked_transformer_p.input_dropout_prob, ++ gating_func=stacked_transformer_p.gating_func, ++ unadjusted_expert_capacity_factor=stacked_transformer_p.unadjusted_expert_capacity_factor, ++ packed_input=stacked_transformer_p.packed_input, ++ fold_padding_with_segment_mask=stacked_transformer_p.fold_padding_with_segment_mask, ++ moe_layer_tpl=stacked_transformer_p.moe_layer_tpl, ++ num_experts=stacked_transformer_p.num_experts, ++ num_groups=stacked_transformer_p.num_groups, ++ min_group_size=stacked_transformer_p.min_group_size, ++ moe_layers=stacked_transformer_p.moe_layers, ++ ngrammer_tpls=stacked_transformer_p.ngrammer_tpls ++ ) ++ ++ ori_transformer_engine_p = stacked_transformer_p.transformer_layer_params_tpl ++ ++ te_stacked_transformer_p.transformer_layer_params_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, ++ name='transformer_layer', ++ params_init=stacked_transformer_p.params_init, ++ dtype=dtype, ++ hidden_size=stacked_transformer_p.model_dims, ++ mlp_hidden_size=stacked_transformer_p.hidden_dims, ++ num_attention_heads=stacked_transformer_p.num_heads, ++ layernorm_type='layernorm', ++ layernorm_epsilon=ori_transformer_engine_p.ln_tpl.epsilon, ++ zero_centered_gamma = True, ++ hidden_dropout=ori_transformer_engine_p.residual_dropout_prob, ++ attention_dropout=ori_transformer_engine_p.atten_dropout_prob, ++ mlp_activations=('gelu',), ++ use_bias=True, ++ layer_type=te_praxis.TransformerLayerType.ENCODER, ++ self_attn_mask_type='causal', ++ enable_relative_embedding=False, ++ drop_path=ori_transformer_engine_p.residual_droppath_prob, ++ scaled_query_init=False, ++ scale_attn_logits=True, ++ transpose_batch_sequence=False ++ ) ++ ++ return te_stacked_transformer_p ++ ++ @staticmethod ++ def update_fp8_metas_if_needed(mdl_vars, grads): ++ FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME ++ if FP8_COLLECTION_NAME in grads: ++ mdl_vars[FP8_COLLECTION_NAME] = te.update_fp8_metas(grads)[FP8_COLLECTION_NAME] ++ return mdl_vars ++ ++ @staticmethod ++ def include_fp8_for_grads_if_needed(variables): ++ FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME ++ if FP8_COLLECTION_NAME in variables: ++ variables[FP8_COLLECTION_NAME] = \ ++ jax.tree_util.tree_map(lambda x: False, variables[FP8_COLLECTION_NAME]) ++ return variables ++ ++ @staticmethod ++ def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): ++ FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME ++ if FP8_COLLECTION_NAME in grads: ++ grads[FP8_COLLECTION_NAME] = vars_with_opt[FP8_COLLECTION_NAME].copy() ++ return grads ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): ++ fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID, ++ amax_history_len=1024, amax_compute_algo='max') ++ ++ enable_fp8 = bool(int((os.environ.get("ENABLE_FP8", False)))) ++ try: ++ with te.fp8_autocast(enabled=enable_fp8, ++ fp8_recipe=fp8_recipe, ++ sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis)): ++ yield ++ finally: ++ pass ++ ++ ++class TransformerEngineHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def is_enabled_te(): ++ enable_te = bool(int((os.environ.get("ENABLE_TE", False)))) ++ return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te) ++ ++ @staticmethod ++ def get_helper(): ++ if TransformerEngineHelper.is_enabled_te(): ++ return TEInstalledHelper ++ return TENotInstalledHelper ++ ++ @staticmethod ++ def get_stack_transformer(stacked_transformer_p, dtype): ++ return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype) ++ ++ @staticmethod ++ def update_fp8_metas_if_needed(mdl_vars, grads): ++ return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) ++ ++ @staticmethod ++ def include_fp8_for_grads_if_needed(variables): ++ return TransformerEngineHelper.get_helper().include_fp8_for_grads_if_needed(variables) ++ ++ @staticmethod ++ def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): ++ return TransformerEngineHelper.get_helper().mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): ++ try: ++ with TransformerEngineHelper.get_helper().fp8_autocast(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis): ++ yield ++ finally: ++ pass +diff --git a/paxml/main.py b/paxml/main.py +index 41b1a9c..c2e866a 100644 +--- a/paxml/main.py ++++ b/paxml/main.py +@@ -50,6 +50,7 @@ from paxml import tf_data_service_lib + from paxml import train + from paxml import trainer_lib + from paxml import tuning_lib ++from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper + from praxis import pax_fiddle + from praxis import py_utils + +@@ -489,39 +490,41 @@ def _main(argv: Sequence[str]) -> None: + FLAGS.host_idx) + ) + +- if FLAGS.exp is not None: +- experiment_config = get_experiment(FLAGS.exp)() +- elif absl_flags.fdl_flags_supplied(): +- # Use the legacy Fiddle flags API to parse command line Fiddle flags. +- cfg = absl_flags.create_buildable_from_flags( +- module=None, allow_imports=True) +- experiment_config = pax_fiddle.build(cfg) +- logging.warning( +- 'Legacy Fiddle flags API usage detected. Please use the new Fiddle' +- ' command line flag `fdl` with various commands to specify the' +- ' config and any overrides. Please see' +- ' `fiddle/docs/flags_code_lab.md` for more' +- ' documentation on Fiddle flags usage.' +- ) +- elif _FIDDLE_CONFIG.value is not None: +- # This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse +- # command line Fiddle flags. See +- # `fiddle/docs/flags_code_lab.md` for details on the new +- # Fiddle flags API. +- logging.info( +- 'Using pax_fiddle_config from the command line: %s', +- _FIDDLE_CONFIG.value, +- ) +- experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value) +- else: +- raise app.UsageError( +- 'No experiment provided. At least one of --exp, --fdl,' +- ' --fdl_config, or --fdl_config_file is required.' +- ) ++ with TransformerEngineHelper.fp8_autocast('replica', 'mdl', 'data'): ++ if FLAGS.exp is not None: ++ experiment_config = get_experiment(FLAGS.exp)() ++ elif absl_flags.fdl_flags_supplied(): ++ # Use the legacy Fiddle flags API to parse command line Fiddle flags. ++ cfg = absl_flags.create_buildable_from_flags( ++ module=None, allow_imports=True) ++ experiment_config = pax_fiddle.build(cfg) ++ logging.warning( ++ 'Legacy Fiddle flags API usage detected. Please use the new Fiddle' ++ ' command line flag `fdl` with various commands to specify the' ++ ' config and any overrides. Please see' ++ ' `fiddle/docs/flags_code_lab.md` for more' ++ ' documentation on Fiddle flags usage.' ++ ) ++ elif _FIDDLE_CONFIG.value is not None: ++ # This uses the new Fiddle flags API `DEFINE_fiddle_config()` to parse ++ # command line Fiddle flags. See ++ # `fiddle/docs/flags_code_lab.md` for details on the new ++ # Fiddle flags API. ++ logging.info( ++ 'Using pax_fiddle_config from the command line: %s', ++ _FIDDLE_CONFIG.value, ++ ) ++ experiment_config = pax_fiddle.build(_FIDDLE_CONFIG.value) ++ else: ++ raise app.UsageError( ++ 'No experiment provided. At least one of --exp, --fdl,' ++ ' --fdl_config, or --fdl_config_file is required.' ++ ) ++ + +- experiment_config.validate() +- run(experiment_config=experiment_config, +- enable_checkpoint_saving=FLAGS.enable_checkpoint_saving) ++ experiment_config.validate() ++ run(experiment_config=experiment_config, ++ enable_checkpoint_saving=FLAGS.enable_checkpoint_saving) + + + _TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+') +diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py +index 017d2b3..3dc837d 100644 +--- a/paxml/tasks_lib.py ++++ b/paxml/tasks_lib.py +@@ -43,6 +43,7 @@ from paxml import checkpoint_types + from paxml import io_utils + from paxml import learners as learners_lib + from paxml import train_states ++from paxml.contrib.gpu.scripts_gpu.te_helper import DEFAULT_INIT_MUTABLE_LIST + from praxis import asserts + from praxis import base_hyperparams + from praxis import base_input +@@ -1779,7 +1780,7 @@ class SingleTask(base_task.BaseTask): + inputs_shape_dtype) + # Initialize with a dummy seed + var_weight_hparams = ckpt_task.model.abstract_init_with_metadata( +- inputs_shape_dtype) ++ inputs_shape_dtype, mutable=DEFAULT_INIT_MUTABLE_LIST) + ckpt_train_state = ckpt_task.create_train_state_padded_shapes( + var_weight_hparams) + train_state_pspecs = ckpt_task.create_train_state_partition_specs( +diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py +index 587181d..e7fe54a 100644 +--- a/paxml/trainer_lib.py ++++ b/paxml/trainer_lib.py +@@ -35,6 +35,7 @@ from paxml import learners as learners_lib + from paxml import sgf + from paxml import tasks_lib + from paxml import train_states ++from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper, DEFAULT_INIT_MUTABLE_LIST + from praxis import asserts + from praxis import base_hyperparams + from praxis import base_input +@@ -167,8 +168,7 @@ def create_train_state_metadata( + A TrainStateMetadata instance. + """ + var_weight_hparams = jax_task.model.abstract_init_with_metadata( +- train_shape_dtype, do_eval=do_eval +- ) ++ train_shape_dtype, do_eval=do_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) + padded_global_shapes = jax_task.create_train_state_padded_shapes( + var_weight_hparams, discard_opt_states=discard_opt_states + ) +@@ -217,7 +217,8 @@ def write_post_init_model_hparams_file( + logging.info('post_init_model_params: %s', params_fpath) + job_log_dir.mkdir(parents=True, exist_ok=True) + hyper_params = model.abstract_init_with_mdl_config( +- train_state_metadata.input_shape_dtype, do_eval=do_eval ++ train_state_metadata.input_shape_dtype, do_eval=do_eval, ++ extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST + ) + with params_fpath.open('w') as params_file: + hyper_params_dump = base_hyperparams.nested_struct_to_text(hyper_params) +@@ -379,7 +380,8 @@ def initialize_model_state( + is_eval_for_init = is_eval + if not var_weight_hparams: + var_weight_hparams = model.abstract_init_with_metadata( +- inputs_shape_dtype, do_eval=is_eval_for_init ++ inputs_shape_dtype, do_eval=is_eval_for_init, ++ extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST + ) + logging.info('init_var prng_seed: %s', init_key) + logging.info('var_weight_hparams: %s', var_weight_hparams) +@@ -396,7 +398,7 @@ def initialize_model_state( + inputs = jax.tree_map(jnp.zeros_like, inputs_shape_dtype) + if model.hparams.fprop_dtype == jnp.bfloat16: + inputs = jax.tree_map(_maybe_to_bfloat16, inputs) +- return model.init(init_key, inputs) ++ return model.init(init_key, inputs, mutable=DEFAULT_INIT_MUTABLE_LIST) + + initial_vars = init_fn(init_key) + logging.info('initial_vars: %s', jax.tree_map(jnp.shape, initial_vars)) +@@ -809,7 +811,6 @@ class LossFnProtocol(Protocol): + ) -> tuple[JTensor, sgf.GradAuxInfo]: + """Produces losses and grad info by passing the inputs through a model.""" + +- + def _get_default_loss_fn( + jax_task: tasks_lib.SingleTask, + context_p: base_layer.JaxContext.HParams, +@@ -994,14 +995,16 @@ def get_excluded_var_masks( + excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad( + var_weight_hparams, learner + ) +- _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad) ++ excluded_for_grad_but_fp8_meta = TransformerEngineHelper.include_fp8_for_grads_if_needed(excluded_for_grad.copy()) ++ ++ _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad_but_fp8_meta) + + # Excluded for optimizer states. + excluded_for_opt = tasks_lib.get_excluded_var_mask_for_opt( + var_weight_hparams, + learner, + ) +- return excluded_for_grad, excluded_for_opt ++ return excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt + + + # TODO(yonghui): refactor to pass in learner separately. +@@ -1067,7 +1070,7 @@ def train_step_single_learner( + + if not var_weight_hparams: + with base_layer.JaxContext.new_context(hparams=context_p): +- var_weight_hparams = model.abstract_init_with_metadata(inputs) ++ var_weight_hparams = model.abstract_init_with_metadata(inputs, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) + updated_model_vars = jax_task.maybe_adjust_train_state( # pytype: disable=wrong-arg-types # jax-ndarray + step=states.step, + mdl_vars=states.mdl_vars, +@@ -1077,13 +1080,13 @@ def train_step_single_learner( + + _, subkey = jax.random.split(prng_key) + +- excluded_for_grad, excluded_for_opt = get_excluded_var_masks( ++ excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt = get_excluded_var_masks( + var_weight_hparams, learner + ) + + # Construct and call the grad function. + if not grad_fn: +- grad_fn = _get_default_grad_fn(excluded_for_grad, excluded_for_opt) ++ grad_fn = _get_default_grad_fn(excluded_for_grad_but_fp8_meta, excluded_for_opt) + (weighted_loss, aux_info), grads = grad_fn( + loss_fn=_get_default_loss_fn( + jax_task=jax_task, +@@ -1131,7 +1134,7 @@ def train_step_single_learner( + # Make updated non-trainable vars visible to EMA. + mdl_vars[NON_TRAINABLE] = fwd_updated_vars[NON_TRAINABLE] + excluded_for_learner = jax.tree_map( +- lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad ++ lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad_but_fp8_meta + ) + vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt( + mdl_vars, excluded_for_learner +@@ -1139,6 +1142,10 @@ def train_step_single_learner( + wps_with_opt = tasks_lib.filter_vars_for_grad_or_opt( + var_weight_hparams, excluded_for_learner + ) ++ ++ mdl_vars = TransformerEngineHelper.update_fp8_metas_if_needed(mdl_vars, grads) ++ grads = TransformerEngineHelper.mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) ++ + transformed_grads, new_opt_states = learner.update_states( + grads, states.opt_states[0], vars_with_opt, wps_with_opt + ) +@@ -1174,6 +1181,7 @@ def train_step_single_learner( + states.mdl_vars, + mdl_vars, + ) ++ + new_states = states.new_state( + mdl_vars=mdl_vars, opt_states=[new_opt_states], extra_state=() + ) +@@ -1268,7 +1276,7 @@ def eval_step_single_learner( + var_weight_hparams = model.abstract_init_with_metadata( + inputs, + do_eval=not jax_task.hparams.train.always_use_train_for_model_init, +- ) ++ extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) + + if fprop_dtype == jnp.float32: + pass +@@ -1519,7 +1527,7 @@ def initialize_partitioned_model_states( + model = jax_task.model + if not var_weight_hparams: + var_weight_hparams = model.abstract_init_with_metadata( +- global_input_shapes, do_eval=is_eval ++ global_input_shapes, do_eval=is_eval, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST + ) + + train_state_partition_specs = ( +-- +2.25.1 + + +From 9d6b6db6039d7e6658dd179e5838379c7dc967e3 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 27 Sep 2023 10:46:53 +0800 +Subject: [PATCH 2/9] Adding dropout support when enabling TE. + +--- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 10 ++++++++++ + 1 file changed, 10 insertions(+) + +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index d44ca67..2b9dba4 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -59,6 +59,16 @@ class StackedTransformer(transformers.StackedTransformer): + p_i.num_attention_heads = self.num_heads + p_i.hidden_size = self.model_dims + p_i.mlp_hidden_size = self.hidden_dims ++ ++ p_i.dropout_rng_name = base_layer.RANDOM ++ p_i.attention_dropout = self.atten_dropout_prob or self.dropout_prob ++ p_i.hidden_dropout = self.residual_dropout_prob or self.dropout_prob ++ p_i.intermediate_dropout = self.relu_dropout_prob or self.dropout_prob ++ if self.residual_droppath_prob > 0.0: ++ p_i.drop_path = ( ++ self.residual_droppath_prob * i / max(1, self.num_layers) ++ ) ++ + assert self.dim_per_head == self.model_dims // self.num_heads + assert self.packed_input == False + assert len(self.moe_layers) == 0 +-- +2.25.1 + + +From 1612dc7a1f77f0a515eb4801087a8b4f0756e5b9 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Tue, 24 Oct 2023 10:30:27 +0800 +Subject: [PATCH 3/9] Set deterministic=True for inference. + +--- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 3 ++- + 1 file changed, 2 insertions(+), 1 deletion(-) + +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index 2b9dba4..ef20305 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -141,7 +141,8 @@ class StackedTransformer(transformers.StackedTransformer): + inputs=x_in, + attention_mask=attention_mask, + encoded=cross_inputs, +- encoder_decoder_mask=cross_attention_mask) ++ encoder_decoder_mask=cross_attention_mask, ++ deterministic=self.do_eval) + x_out = checkpoint_name(x_out, 'transformer_layer_out') + return x_out + +-- +2.25.1 + + +From 71507dc4b1396252e6fa746d1299854c204f0c51 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Thu, 2 Nov 2023 22:04:58 -0700 +Subject: [PATCH 4/9] Fix the excluded list for excluded_for_learner + +Signed-off-by: Reese Wang +--- + paxml/trainer_lib.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py +index e7fe54a..4093c3b 100644 +--- a/paxml/trainer_lib.py ++++ b/paxml/trainer_lib.py +@@ -1134,7 +1134,7 @@ def train_step_single_learner( + # Make updated non-trainable vars visible to EMA. + mdl_vars[NON_TRAINABLE] = fwd_updated_vars[NON_TRAINABLE] + excluded_for_learner = jax.tree_map( +- lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad_but_fp8_meta ++ lambda eo, eg: eo and eg, excluded_for_opt, excluded_for_grad + ) + vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt( + mdl_vars, excluded_for_learner +-- +2.25.1 + + +From 2a8233302c7e42b7dc7628c41abb637518d15c29 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Tue, 7 Nov 2023 11:21:53 +0800 +Subject: [PATCH 5/9] Adapting to TE/JAX/Custom_partitioning. + +--- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 6 ++++-- + 1 file changed, 4 insertions(+), 2 deletions(-) + +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index ef20305..fed1601 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -262,7 +262,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + def update_fp8_metas_if_needed(mdl_vars, grads): + FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME + if FP8_COLLECTION_NAME in grads: +- mdl_vars[FP8_COLLECTION_NAME] = te.update_fp8_metas(grads)[FP8_COLLECTION_NAME] ++ mdl_vars[FP8_COLLECTION_NAME] = grads[FP8_COLLECTION_NAME] + return mdl_vars + + @staticmethod +@@ -290,7 +290,9 @@ class TEInstalledHelper(TransformerEngineHelperBase): + try: + with te.fp8_autocast(enabled=enable_fp8, + fp8_recipe=fp8_recipe, +- sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis, fsdp_mesh_axis)): ++ mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis, ++ tp_resource=tp_mesh_axis, ++ fsdp_resource=fsdp_mesh_axis)): + yield + finally: + pass +-- +2.25.1 + + +From 2a6e5a960f438653b4c9cbeb0c016225af853279 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Tue, 7 Nov 2023 15:14:25 +0800 +Subject: [PATCH 6/9] Adding TE-compatiable PipelinedTransformer + +--- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 109 +++++++++++++++++++++ + 1 file changed, 109 insertions(+) + +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index fed1601..5914e54 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -31,6 +31,7 @@ except ModuleNotFoundError as e: + LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] + JTensor = pytypes.JTensor + ++ + class StackedTransformer(transformers.StackedTransformer): + """A mirror of StackedTransformer layers in Praxis.""" + +@@ -147,12 +148,92 @@ class StackedTransformer(transformers.StackedTransformer): + return x_out + + ++class PipelinedTransformer(transformers.PipelinedTransformer): ++ """A mirror of PipelinedTransformer in Praxis""" ++ ++ def __call__( ++ self, ++ inputs: JTensor, ++ paddings: JTensor, ++ segment_mask: JTensor | None = None, ++ cross_inputs: JTensor | None = None, ++ cross_paddings: JTensor | None = None, ++ cross_segment_mask: JTensor | None = None, ++ segment_pos: JTensor | None = None, ++ ) -> JTensor: ++ ++ rules = te_flax.extend_logical_axis_rules(tuple()) ++ batch_mapping = rules[0] ++ hidden_tp_mapping = rules[4] ++ # [Batch, Seqlen, Hidden] ++ bld_mapping = [batch_mapping, None, hidden_tp_mapping] ++ ++ if not self.stream_io: ++ # Annotate the inputs before the pipeline to prevent unexpected ++ # propagation from earlier layers. ++ inputs = base_layer.maybe_shard(inputs, bld_mapping, self.mesh_axis_names) ++ if bld_mapping is not None: ++ # Annotate other broadcast inputs. ++ paddings = base_layer.maybe_shard( ++ paddings, bld_mapping[:-1], self.mesh_axis_names ++ ) ++ ++ # For cross inputs, we only specify the batch dim sharding. ++ def _shard_batch_dim_only(x): ++ return base_layer.maybe_shard( ++ x, ++ [bld_mapping[0]] + [-1] * (x.ndim - 1), ++ self.mesh_axis_names, ++ unconstrained_dims=range(1, x.ndim), ++ ) ++ ++ if segment_mask is not None: ++ segment_mask = _shard_batch_dim_only(segment_mask) ++ if cross_inputs is not None: ++ cross_inputs = _shard_batch_dim_only(cross_inputs) ++ if cross_paddings is not None: ++ cross_paddings = _shard_batch_dim_only(cross_paddings) ++ if cross_segment_mask is not None: ++ cross_segment_mask = _shard_batch_dim_only(cross_segment_mask) ++ ++ if segment_pos is not None: ++ segment_pos = base_layer.maybe_shard( ++ segment_pos, bld_mapping[:-1], self.mesh_axis_names ++ ) ++ ++ outputs = self.pipeline( ++ inputs, ++ paddings, ++ segment_mask=segment_mask, ++ cross_inputs=cross_inputs, ++ cross_paddings=cross_paddings, ++ cross_segment_mask=cross_segment_mask, ++ segment_pos=segment_pos, ++ ) ++ ++ if not self.stream_io: ++ outputs = base_layer.maybe_shard( ++ outputs, bld_mapping, self.mesh_axis_names ++ ) ++ ++ outputs = base_layer.maybe_shard( ++ outputs, ++ self.activation_split_dims_mapping.final_out, ++ self.mesh_axis_names, ++ ) ++ return outputs ++ ++ + class TransformerEngineHelperBase: + + @staticmethod + def get_stack_transformer(stacked_transformer_p, dtype): + raise NotImplementedError + ++ @staticmethod ++ def get_pipeline_transformer(pipeline_transformer_p): ++ raise NotImplementedError ++ + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + raise NotImplementedError +@@ -177,6 +258,10 @@ class TENotInstalledHelper(TransformerEngineHelperBase): + def get_stack_transformer(stacked_transformer_p, dtype): + return stacked_transformer_p + ++ @staticmethod ++ def get_pipeline_transformer(pipeline_transformer_p): ++ return pipeline_transformer_p ++ + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + return mdl_vars +@@ -258,6 +343,26 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + return te_stacked_transformer_p + ++ @staticmethod ++ def get_pipeline_transformer(pipeline_transformer_p): ++ ++ assert pipeline_transformer_p.cls == transformers.PipelinedTransformer ++ ++ te_pipeline_transformer_p = pax_fiddle.Config(PipelinedTransformer, ++ pipeline_stage=pipeline_transformer_p.pipeline_stage, ++ circular_repeat=pipeline_transformer_p.circular_repeat, ++ num_pipeline_stages=pipeline_transformer_p.num_pipeline_stages, ++ num_pipeline_microbatches=pipeline_transformer_p.num_pipeline_microbatches, ++ pipeline_microbatch_size=pipeline_transformer_p.pipeline_microbatch_size, ++ stream_io=pipeline_transformer_p.stream_io, ++ pipeline_broadcast_inputs=pipeline_transformer_p.pipeline_broadcast_inputs, ++ checkpoint_policy=pipeline_transformer_p.checkpoint_policy, ++ enable_async_circular_transfer=pipeline_transformer_p.enable_async_circular_transfer, ++ bf16_accum_in_fp32=pipeline_transformer_p.bf16_accum_in_fp32 ++ ) ++ ++ return te_pipeline_transformer_p ++ + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME +@@ -315,6 +420,10 @@ class TransformerEngineHelper(TransformerEngineHelperBase): + def get_stack_transformer(stacked_transformer_p, dtype): + return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype) + ++ @staticmethod ++ def get_pipeline_transformer(pipeline_transformer_p): ++ return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p) ++ + @staticmethod + def update_fp8_metas_if_needed(mdl_vars, grads): + return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) +-- +2.25.1 + + +From b57188225e7890dfc54d70db7d89fcb32e61e762 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 8 Nov 2023 10:06:49 +0800 +Subject: [PATCH 7/9] Apply OWG to TE's FP8 meta + +--- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 59 ---------------------- + paxml/trainer_lib.py | 12 ++--- + 2 files changed, 4 insertions(+), 67 deletions(-) + +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index 5914e54..fd482df 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -2,7 +2,6 @@ import os + from contextlib import contextmanager + from typing import Optional, Sequence + +-import jax + import jax.numpy as jnp + from jax.ad_checkpoint import checkpoint_name + from praxis import base_layer +@@ -234,18 +233,6 @@ class TransformerEngineHelperBase: + def get_pipeline_transformer(pipeline_transformer_p): + raise NotImplementedError + +- @staticmethod +- def update_fp8_metas_if_needed(mdl_vars, grads): +- raise NotImplementedError +- +- @staticmethod +- def include_fp8_for_grads_if_needed(variables): +- raise NotImplementedError +- +- @staticmethod +- def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): +- raise NotImplementedError +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +@@ -262,18 +249,6 @@ class TENotInstalledHelper(TransformerEngineHelperBase): + def get_pipeline_transformer(pipeline_transformer_p): + return pipeline_transformer_p + +- @staticmethod +- def update_fp8_metas_if_needed(mdl_vars, grads): +- return mdl_vars +- +- @staticmethod +- def include_fp8_for_grads_if_needed(variables): +- return variables +- +- @staticmethod +- def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): +- return grads +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +@@ -363,28 +338,6 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + return te_pipeline_transformer_p + +- @staticmethod +- def update_fp8_metas_if_needed(mdl_vars, grads): +- FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME +- if FP8_COLLECTION_NAME in grads: +- mdl_vars[FP8_COLLECTION_NAME] = grads[FP8_COLLECTION_NAME] +- return mdl_vars +- +- @staticmethod +- def include_fp8_for_grads_if_needed(variables): +- FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME +- if FP8_COLLECTION_NAME in variables: +- variables[FP8_COLLECTION_NAME] = \ +- jax.tree_util.tree_map(lambda x: False, variables[FP8_COLLECTION_NAME]) +- return variables +- +- @staticmethod +- def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): +- FP8_COLLECTION_NAME = te.fp8.FP8Helper.FP8_COLLECTION_NAME +- if FP8_COLLECTION_NAME in grads: +- grads[FP8_COLLECTION_NAME] = vars_with_opt[FP8_COLLECTION_NAME].copy() +- return grads +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +@@ -424,18 +377,6 @@ class TransformerEngineHelper(TransformerEngineHelperBase): + def get_pipeline_transformer(pipeline_transformer_p): + return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p) + +- @staticmethod +- def update_fp8_metas_if_needed(mdl_vars, grads): +- return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) +- +- @staticmethod +- def include_fp8_for_grads_if_needed(variables): +- return TransformerEngineHelper.get_helper().include_fp8_for_grads_if_needed(variables) +- +- @staticmethod +- def mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt): +- return TransformerEngineHelper.get_helper().mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +diff --git a/paxml/trainer_lib.py b/paxml/trainer_lib.py +index 4093c3b..2e8fc35 100644 +--- a/paxml/trainer_lib.py ++++ b/paxml/trainer_lib.py +@@ -995,16 +995,15 @@ def get_excluded_var_masks( + excluded_for_grad = tasks_lib.get_excluded_var_mask_for_grad( + var_weight_hparams, learner + ) +- excluded_for_grad_but_fp8_meta = TransformerEngineHelper.include_fp8_for_grads_if_needed(excluded_for_grad.copy()) + +- _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad_but_fp8_meta) ++ _log_bprop_include_exclude_list(var_weight_hparams, excluded_for_grad) + + # Excluded for optimizer states. + excluded_for_opt = tasks_lib.get_excluded_var_mask_for_opt( + var_weight_hparams, + learner, + ) +- return excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt ++ return excluded_for_grad, excluded_for_opt + + + # TODO(yonghui): refactor to pass in learner separately. +@@ -1080,13 +1079,13 @@ def train_step_single_learner( + + _, subkey = jax.random.split(prng_key) + +- excluded_for_grad, excluded_for_grad_but_fp8_meta, excluded_for_opt = get_excluded_var_masks( ++ excluded_for_grad, excluded_for_opt = get_excluded_var_masks( + var_weight_hparams, learner + ) + + # Construct and call the grad function. + if not grad_fn: +- grad_fn = _get_default_grad_fn(excluded_for_grad_but_fp8_meta, excluded_for_opt) ++ grad_fn = _get_default_grad_fn(excluded_for_grad, excluded_for_opt) + (weighted_loss, aux_info), grads = grad_fn( + loss_fn=_get_default_loss_fn( + jax_task=jax_task, +@@ -1143,9 +1142,6 @@ def train_step_single_learner( + var_weight_hparams, excluded_for_learner + ) + +- mdl_vars = TransformerEngineHelper.update_fp8_metas_if_needed(mdl_vars, grads) +- grads = TransformerEngineHelper.mask_out_fp8_meta_grads_if_needed(grads, vars_with_opt) +- + transformed_grads, new_opt_states = learner.update_states( + grads, states.opt_states[0], vars_with_opt, wps_with_opt + ) +-- +2.25.1 + + +From c43766ee2e8cda686176a3895e87150b10d5de5e Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 15 Nov 2023 14:43:17 +0800 +Subject: [PATCH 8/9] Remove Praxis related setup (Moving to Praxis TE/Patch) + +--- + paxml/contrib/gpu/scripts_gpu/configs.py | 9 - + paxml/contrib/gpu/scripts_gpu/te_helper.py | 315 --------------------- + 2 files changed, 324 deletions(-) + +diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py +index 50b848e..0a31555 100644 +--- a/paxml/contrib/gpu/scripts_gpu/configs.py ++++ b/paxml/contrib/gpu/scripts_gpu/configs.py +@@ -166,15 +166,6 @@ class GPT126M(TransformerLmSpmdAdam): + transformer_layer_p = stacked_p.transformer_layer_params_tpl + transformer_layer_p.ln_tpl.reductions_in_fp32 = True + transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True +- else: +- stacked_p = TransformerEngineHelper.get_stack_transformer( +- stacked_p, jnp.dtype(self.FPROP_DTYPE)) +- if issubclass(fdl.get_callable(model_p.lm_tpl.stacked_transformer_tpl), +- transformers.StackedTransformerRepeated): +- model_p.lm_tpl.stacked_transformer_tpl.block = stacked_p +- else: +- model_p.lm_tpl.stacked_transformer_tpl = stacked_p +- + + model_p.params_init = WeightInit.Gaussian(self.INIT_STD) + softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index fd482df..b271258 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -1,238 +1,17 @@ + import os + from contextlib import contextmanager +-from typing import Optional, Sequence +- +-import jax.numpy as jnp +-from jax.ad_checkpoint import checkpoint_name +-from praxis import base_layer +-from praxis import pax_fiddle +-from praxis import pytypes +-from praxis.layers import transformers +-from praxis.layers import stochastics + + try: + import transformer_engine.jax as te +- import transformer_engine.jax.flax as te_flax +- import transformer_engine.jax.praxis as te_praxis + from transformer_engine.common import recipe + _IS_TRANSFORMER_ENGINE_INSTALLED = True +- DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME] +- import praxis.layers.repeats as praxis_repeat +- # This is to make Repeat module correctly generate collections we need. +- praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes +- te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) + + except ModuleNotFoundError as e: + _IS_TRANSFORMER_ENGINE_INSTALLED = False +- DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST +- +- +-LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] +-JTensor = pytypes.JTensor +- +- +-class StackedTransformer(transformers.StackedTransformer): +- """A mirror of StackedTransformer layers in Praxis.""" +- +- def setup(self) -> None: +- +- assert self.num_layers > 0 +- assert self.model_dims > 0 +- assert self.hidden_dims > 0 +- assert self.num_heads > 0 +- assert 0.0 <= self.dropout_prob < 1.0 +- assert 0.0 <= self.input_dropout_prob < 1.0 +- +- def _layer_params(i): +- """Construct i-th layer params.""" +- if isinstance(self.transformer_layer_params_tpl, Sequence): +- factor = self.num_layers // len(self.transformer_layer_params_tpl) +- ii = i // factor +- p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii]) +- else: +- p_i = self._clone_layer_params(self.transformer_layer_params_tpl) +- p_i.name = f'layer_{i}' +- +- p_i.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) +- p_i.layer_type = te_praxis.TransformerLayerType.DECODER if self.use_cross_attention \ +- else te_praxis.TransformerLayerType.ENCODER +- p_i.num_attention_heads = self.num_heads +- p_i.hidden_size = self.model_dims +- p_i.mlp_hidden_size = self.hidden_dims +- +- p_i.dropout_rng_name = base_layer.RANDOM +- p_i.attention_dropout = self.atten_dropout_prob or self.dropout_prob +- p_i.hidden_dropout = self.residual_dropout_prob or self.dropout_prob +- p_i.intermediate_dropout = self.relu_dropout_prob or self.dropout_prob +- if self.residual_droppath_prob > 0.0: +- p_i.drop_path = ( +- self.residual_droppath_prob * i / max(1, self.num_layers) +- ) +- +- assert self.dim_per_head == self.model_dims // self.num_heads +- assert self.packed_input == False +- assert len(self.moe_layers) == 0 +- assert self.ngrammer_tpls is None +- +- if self.ngrammer_tpls is not None: +- if self.ngrammer_tpls[i] is not None: +- p_i.ngrammer_tpl = self.ngrammer_tpls[i] +- return p_i +- +- if isinstance(self.transformer_layer_params_tpl, (list, tuple)): +- if self.num_layers % len(self.transformer_layer_params_tpl): +- raise ValueError('num_layers should be divisible by ' +- 'transformer_layer_params_tpl') +- +- layer_params = [_layer_params(i) for i in range(self.num_layers)] +- self.create_children('x_layers', layer_params) +- +- if self.input_dropout_prob > 0.0: +- self.create_child( +- 'input_dropout', +- pax_fiddle.Config( +- stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob +- ), +- ) +- +- def __call__(self, +- inputs: JTensor, +- paddings: JTensor, +- segment_mask: Optional[JTensor] = None, +- cross_inputs: Optional[JTensor] = None, +- cross_paddings: Optional[JTensor] = None, +- cross_segment_mask: Optional[JTensor] = None, +- segment_pos: Optional[JTensor] = None) -> JTensor: +- +- if self.packed_input: +- assert segment_mask is not None +- +- if self.use_cross_attention: +- assert cross_inputs is not None +- assert cross_paddings is not None +- if self.packed_input: +- assert cross_segment_mask is not None +- +- attention_mask, cross_attention_mask = transformers.compute_attention_masks_for_fprop( +- inputs, +- paddings, +- self.mask_self_attention, +- segment_mask, +- cross_inputs, +- cross_paddings, +- cross_segment_mask, +- fold_padding_with_segment_mask=self.fold_padding_with_segment_mask, +- ) +- +- x_out = inputs +- if self.input_dropout_prob > 0.0: +- x_out = self.input_dropout(x_out) +- +- attention_mask = 1 - (attention_mask == 0) +- attention_mask = attention_mask.astype(jnp.uint8) +- +- if cross_attention_mask is not None: +- cross_attention_mask = 1 - (cross_attention_mask == 0) +- cross_attention_mask = cross_attention_mask.astype(jnp.uint8) +- +- for i in range(self.num_layers): +- x_in = x_out +- x_out = self.x_layers[i]( +- inputs=x_in, +- attention_mask=attention_mask, +- encoded=cross_inputs, +- encoder_decoder_mask=cross_attention_mask, +- deterministic=self.do_eval) +- x_out = checkpoint_name(x_out, 'transformer_layer_out') +- return x_out +- +- +-class PipelinedTransformer(transformers.PipelinedTransformer): +- """A mirror of PipelinedTransformer in Praxis""" +- +- def __call__( +- self, +- inputs: JTensor, +- paddings: JTensor, +- segment_mask: JTensor | None = None, +- cross_inputs: JTensor | None = None, +- cross_paddings: JTensor | None = None, +- cross_segment_mask: JTensor | None = None, +- segment_pos: JTensor | None = None, +- ) -> JTensor: +- +- rules = te_flax.extend_logical_axis_rules(tuple()) +- batch_mapping = rules[0] +- hidden_tp_mapping = rules[4] +- # [Batch, Seqlen, Hidden] +- bld_mapping = [batch_mapping, None, hidden_tp_mapping] +- +- if not self.stream_io: +- # Annotate the inputs before the pipeline to prevent unexpected +- # propagation from earlier layers. +- inputs = base_layer.maybe_shard(inputs, bld_mapping, self.mesh_axis_names) +- if bld_mapping is not None: +- # Annotate other broadcast inputs. +- paddings = base_layer.maybe_shard( +- paddings, bld_mapping[:-1], self.mesh_axis_names +- ) +- +- # For cross inputs, we only specify the batch dim sharding. +- def _shard_batch_dim_only(x): +- return base_layer.maybe_shard( +- x, +- [bld_mapping[0]] + [-1] * (x.ndim - 1), +- self.mesh_axis_names, +- unconstrained_dims=range(1, x.ndim), +- ) +- +- if segment_mask is not None: +- segment_mask = _shard_batch_dim_only(segment_mask) +- if cross_inputs is not None: +- cross_inputs = _shard_batch_dim_only(cross_inputs) +- if cross_paddings is not None: +- cross_paddings = _shard_batch_dim_only(cross_paddings) +- if cross_segment_mask is not None: +- cross_segment_mask = _shard_batch_dim_only(cross_segment_mask) +- +- if segment_pos is not None: +- segment_pos = base_layer.maybe_shard( +- segment_pos, bld_mapping[:-1], self.mesh_axis_names +- ) +- +- outputs = self.pipeline( +- inputs, +- paddings, +- segment_mask=segment_mask, +- cross_inputs=cross_inputs, +- cross_paddings=cross_paddings, +- cross_segment_mask=cross_segment_mask, +- segment_pos=segment_pos, +- ) +- +- if not self.stream_io: +- outputs = base_layer.maybe_shard( +- outputs, bld_mapping, self.mesh_axis_names +- ) +- +- outputs = base_layer.maybe_shard( +- outputs, +- self.activation_split_dims_mapping.final_out, +- self.mesh_axis_names, +- ) +- return outputs + + + class TransformerEngineHelperBase: + +- @staticmethod +- def get_stack_transformer(stacked_transformer_p, dtype): +- raise NotImplementedError +- +- @staticmethod +- def get_pipeline_transformer(pipeline_transformer_p): +- raise NotImplementedError +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +@@ -241,14 +20,6 @@ class TransformerEngineHelperBase: + + class TENotInstalledHelper(TransformerEngineHelperBase): + +- @staticmethod +- def get_stack_transformer(stacked_transformer_p, dtype): +- return stacked_transformer_p +- +- @staticmethod +- def get_pipeline_transformer(pipeline_transformer_p): +- return pipeline_transformer_p +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +@@ -260,84 +31,6 @@ class TENotInstalledHelper(TransformerEngineHelperBase): + + class TEInstalledHelper(TransformerEngineHelperBase): + +- @staticmethod +- def get_stack_transformer(stacked_transformer_p, dtype): +- +- assert stacked_transformer_p.cls == transformers.StackedTransformer +- +- te_stacked_transformer_p = pax_fiddle.Config(StackedTransformer, +- use_cross_attention=stacked_transformer_p.use_cross_attention, +- mask_self_attention=stacked_transformer_p.mask_self_attention, +- num_layers=stacked_transformer_p.num_layers, +- model_dims=stacked_transformer_p.model_dims, +- hidden_dims=stacked_transformer_p.hidden_dims, +- num_heads=stacked_transformer_p.num_heads, +- dim_per_head=stacked_transformer_p.dim_per_head, +- dropout_prob=stacked_transformer_p.dropout_prob, +- atten_dropout_prob=stacked_transformer_p.atten_dropout_prob, +- residual_dropout_prob=stacked_transformer_p.residual_dropout_prob, +- relu_dropout_prob=stacked_transformer_p.relu_dropout_prob, +- residual_droppath_prob=stacked_transformer_p.residual_droppath_prob, +- input_dropout_prob=stacked_transformer_p.input_dropout_prob, +- gating_func=stacked_transformer_p.gating_func, +- unadjusted_expert_capacity_factor=stacked_transformer_p.unadjusted_expert_capacity_factor, +- packed_input=stacked_transformer_p.packed_input, +- fold_padding_with_segment_mask=stacked_transformer_p.fold_padding_with_segment_mask, +- moe_layer_tpl=stacked_transformer_p.moe_layer_tpl, +- num_experts=stacked_transformer_p.num_experts, +- num_groups=stacked_transformer_p.num_groups, +- min_group_size=stacked_transformer_p.min_group_size, +- moe_layers=stacked_transformer_p.moe_layers, +- ngrammer_tpls=stacked_transformer_p.ngrammer_tpls +- ) +- +- ori_transformer_engine_p = stacked_transformer_p.transformer_layer_params_tpl +- +- te_stacked_transformer_p.transformer_layer_params_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, +- name='transformer_layer', +- params_init=stacked_transformer_p.params_init, +- dtype=dtype, +- hidden_size=stacked_transformer_p.model_dims, +- mlp_hidden_size=stacked_transformer_p.hidden_dims, +- num_attention_heads=stacked_transformer_p.num_heads, +- layernorm_type='layernorm', +- layernorm_epsilon=ori_transformer_engine_p.ln_tpl.epsilon, +- zero_centered_gamma = True, +- hidden_dropout=ori_transformer_engine_p.residual_dropout_prob, +- attention_dropout=ori_transformer_engine_p.atten_dropout_prob, +- mlp_activations=('gelu',), +- use_bias=True, +- layer_type=te_praxis.TransformerLayerType.ENCODER, +- self_attn_mask_type='causal', +- enable_relative_embedding=False, +- drop_path=ori_transformer_engine_p.residual_droppath_prob, +- scaled_query_init=False, +- scale_attn_logits=True, +- transpose_batch_sequence=False +- ) +- +- return te_stacked_transformer_p +- +- @staticmethod +- def get_pipeline_transformer(pipeline_transformer_p): +- +- assert pipeline_transformer_p.cls == transformers.PipelinedTransformer +- +- te_pipeline_transformer_p = pax_fiddle.Config(PipelinedTransformer, +- pipeline_stage=pipeline_transformer_p.pipeline_stage, +- circular_repeat=pipeline_transformer_p.circular_repeat, +- num_pipeline_stages=pipeline_transformer_p.num_pipeline_stages, +- num_pipeline_microbatches=pipeline_transformer_p.num_pipeline_microbatches, +- pipeline_microbatch_size=pipeline_transformer_p.pipeline_microbatch_size, +- stream_io=pipeline_transformer_p.stream_io, +- pipeline_broadcast_inputs=pipeline_transformer_p.pipeline_broadcast_inputs, +- checkpoint_policy=pipeline_transformer_p.checkpoint_policy, +- enable_async_circular_transfer=pipeline_transformer_p.enable_async_circular_transfer, +- bf16_accum_in_fp32=pipeline_transformer_p.bf16_accum_in_fp32 +- ) +- +- return te_pipeline_transformer_p +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +@@ -369,14 +62,6 @@ class TransformerEngineHelper(TransformerEngineHelperBase): + return TEInstalledHelper + return TENotInstalledHelper + +- @staticmethod +- def get_stack_transformer(stacked_transformer_p, dtype): +- return TransformerEngineHelper.get_helper().get_stack_transformer(stacked_transformer_p, dtype) +- +- @staticmethod +- def get_pipeline_transformer(pipeline_transformer_p): +- return TransformerEngineHelper.get_helper().get_pipeline_transformer(pipeline_transformer_p) +- + @staticmethod + @contextmanager + def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): +-- +2.25.1 + + +From abc0fabc3e2ffb42d1f62254ad42448a39cbd128 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 15 Nov 2023 14:51:14 +0800 +Subject: [PATCH 9/9] Fix missing DEFAULT_INIT_MUTABLE_LIST + +--- + paxml/contrib/gpu/scripts_gpu/te_helper.py | 4 ++++ + 1 file changed, 4 insertions(+) + +diff --git a/paxml/contrib/gpu/scripts_gpu/te_helper.py b/paxml/contrib/gpu/scripts_gpu/te_helper.py +index b271258..cbac7cf 100644 +--- a/paxml/contrib/gpu/scripts_gpu/te_helper.py ++++ b/paxml/contrib/gpu/scripts_gpu/te_helper.py +@@ -1,13 +1,17 @@ + import os + from contextlib import contextmanager + ++from praxis import base_layer ++ + try: + import transformer_engine.jax as te + from transformer_engine.common import recipe + _IS_TRANSFORMER_ENGINE_INSTALLED = True ++ DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + [te.fp8.FP8Helper.FP8_COLLECTION_NAME] + + except ModuleNotFoundError as e: + _IS_TRANSFORMER_ENGINE_INSTALLED = False ++ DEFAULT_INIT_MUTABLE_LIST = base_layer.DEFAULT_INIT_MUTABLE_LIST + + + class TransformerEngineHelperBase: +-- +2.25.1 + diff --git a/.github/container/patches/praxis/PR-27.patch b/.github/container/patches/praxis/PR-27.patch new file mode 100644 index 000000000..516d2fe74 --- /dev/null +++ b/.github/container/patches/praxis/PR-27.patch @@ -0,0 +1,38 @@ +From 9ac9259907137d7cc2edd0d1ac3fd01dbf27801d Mon Sep 17 00:00:00 2001 +From: ashors1 +Date: Mon, 18 Sep 2023 17:40:53 -0700 +Subject: [PATCH] Add alternate method to apply mask to allow XLA to detect MHA + pattern + +--- + praxis/layers/attentions.py | 7 ++++++- + 1 file changed, 6 insertions(+), 1 deletion(-) + +diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py +index a35ce8b..52886bc 100644 +--- a/praxis/layers/attentions.py ++++ b/praxis/layers/attentions.py +@@ -1173,6 +1173,7 @@ class DotProductAttention(base_layer.BaseLayer): + decode_cache: bool = True + attention_mask_summary: bool = False + zero_fully_masked: bool = False ++ mha_mask_addition_pattern: bool = True + qk_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp) + pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp) + per_dim_scale_tpl: LayerTpl = template_field(PerDimScale) +@@ -1524,7 +1525,11 @@ class DotProductAttention(base_layer.BaseLayer): + # Attention softmax is always carried out in fp32. + logits = logits.astype(jnp.float32) + # Apply attention masking +- padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask) ++ if self.mha_mask_addition_pattern: ++ padded_logits = logits + atten_mask.astype(jnp.float32) ++ else: ++ padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask) ++ + if self.attention_mask_summary: + self.add_summary('attention_mask', atten_mask) + if self.attention_extra_logit is None: +-- +2.25.1 + diff --git a/.github/container/patches/t5x/mirror-patch-dali-support.patch b/.github/container/patches/t5x/mirror-patch-dali-support.patch new file mode 100644 index 000000000..7c8b50aa3 --- /dev/null +++ b/.github/container/patches/t5x/mirror-patch-dali-support.patch @@ -0,0 +1,387 @@ +From 29cb12a419a5897d735073ef11f5d5da12021cd6 Mon Sep 17 00:00:00 2001 +From: ashors1 +Date: Tue, 16 May 2023 11:53:31 -0700 +Subject: [PATCH 1/2] add support for DALI datasets + +--- + t5x/train.py | 86 ++++++++++++++++++++++++----- + t5x/trainer.py | 146 +++++++++++++++++++++++++++++++++++++++++++++++++ + 2 files changed, 217 insertions(+), 15 deletions(-) + +diff --git a/t5x/train.py b/t5x/train.py +index 06cd0a3..c0263ca 100644 +--- a/t5x/train.py ++++ b/t5x/train.py +@@ -116,10 +116,13 @@ def train( + ], + inference_evaluator_cls: utils.EvaluatorConstructor = seqio.Evaluator, + get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, ++ prepare_train_iter_fn: Optional[Callable] = utils.prepare_train_iter, + concurrent_metrics: bool = True, + actions: Optional[Mapping[str, Sequence[trainer_lib.BaseAction]]] = None, + train_eval_get_dataset_fn: utils.GetEvalDatasetCallable = utils.get_training_eval_datasets, ++ prepare_eval_iter_fn: Optional[Callable] = None, + run_eval_before_training: bool = False, ++ run_dali_eval: bool = False, + train_state_initializer_cls: Type[ + utils.TrainStateInitializer + ] = utils.TrainStateInitializer, +@@ -169,6 +172,8 @@ def train( + evaluation, potentially with bound configuration args. + get_dataset_fn: The callable use to get the train and train-eval datasets + based on the DatasetConfig and shard information. ++ prepare_train_iter_fn: An optional function that prepares the training input ++ iterator. Defaults to `utils.prepare_train_iter`. + concurrent_metrics: If True, allow metrics computation and logging to + overlap with training. Will likely result in additional TPU memory usage. + actions: A mapping of actions that runs after train, eval or infer_eval, to +@@ -179,8 +184,11 @@ def train( + train_eval_get_dataset_fn: Optional callable use to get the train-eval + datasets based on the DatasetConfig and shard information. If missing, it + defaults to `utils.get_training_eval_datasets`. ++ prepare_eval_iter_fn: An optional function that prepares the eval input ++ iterators. + run_eval_before_training: If True, calculate training eval and inference + eval metrics before training begins. ++ run_dali_eval: Whether to run interleaved evaluation on a DALI dataset. + train_state_initializer_cls: t5x.utils.TrainStateInitializer class for + initializing partitioned TrainState from checkpoints or scratch. + use_orbax: if True, uses Orbax for checkpointing. Experimental feature. +@@ -300,12 +308,15 @@ def train( + train_iter = get_dataset_fn( + train_dataset_cfg, ds_shard_id, num_ds_shards, model.FEATURE_CONVERTER_CLS + ) +- train_iter = utils.prepare_train_iter( +- train_iter, +- checkpoint_cfg=checkpoint_cfg, +- partitioner=partitioner, +- data_layout=data_layout, +- ) ++ ++ if prepare_train_iter_fn: ++ train_iter = utils.prepare_train_iter( ++ train_iter, ++ checkpoint_cfg=checkpoint_cfg, ++ partitioner=partitioner, ++ data_layout=data_layout, ++ ) ++ + input_shapes = jax.tree_map( + lambda x: (data_layout.batch_size, *x.shape[1:]), + train_iter.element_spec, +@@ -321,6 +332,12 @@ def train( + eval_steps, + model.FEATURE_CONVERTER_CLS, + ) # type: Mapping[str, tf.data.Dataset] ++ if prepare_eval_iter_fn: ++ for k in train_eval_datasets: ++ train_eval_datasets[k] = prepare_eval_iter_fn(train_eval_datasets[k], ++ checkpoint_cfg=checkpoint_cfg, ++ partitioner=partitioner, ++ data_layout=data_layout) + if not train_eval_datasets: + logging.warning( + 'No train_eval datasets loaded from config `train_eval_dataset_cfg`: ' +@@ -507,12 +524,19 @@ def train( + def _run_training_eval(first_run: bool = False): + if first_run: + logging.info('Compiling training eval loop.') +- trainer.compile_eval( ++ if run_dali_eval: ++ trainer.compile_eval_dali({ ++ task: utils.get_zeros_batch_like_dataset(ds) ++ for task, ds in train_eval_datasets.items()}, ++ jnp.ones((train_eval_dataset_cfg.batch_size,)).astype(jnp.bool_) ++ ) ++ else: ++ trainer.compile_eval( + { # pytype: disable=wrong-arg-types # jax-ndarray + task: utils.get_zeros_batch_like_dataset(ds) + for task, ds in train_eval_datasets.items() + } +- ) ++ ) + logging.info('Computing training evaluation metrics.') + eval_batch_iters = {} + for task, ds in train_eval_datasets.items(): +@@ -521,13 +545,20 @@ def train( + else: + eval_batch_iters[task] = ds + +- eval_summaries = trainer.eval(eval_batch_iters) +- trainer.stop_training = run_actions( +- trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray +- actions, +- trainer.train_state, +- eval_summaries, +- ) ++ if run_dali_eval: ++ eval_summaries = trainer.eval_dali(eval_batch_iters, ++ train_eval_dataset_cfg.batch_size, ++ ds_shard_id, ++ num_ds_shards, ++ eval_steps) ++ else: ++ eval_summaries = trainer.eval(eval_batch_iters) ++ trainer.stop_training = run_actions( ++ trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray ++ actions, ++ trainer.train_state, ++ eval_summaries, ++ ) + + def _run_inference_eval(): + """Run prediction based inference eval.""" +@@ -556,6 +587,19 @@ def train( + if train_eval_datasets: + logging.info('Running training eval before training.') + _run_training_eval(first_run=True) ++ if run_dali_eval: ++ ## reset eval dataset ++ train_eval_datasets = train_eval_get_dataset_fn( ++ train_eval_dataset_cfg, ds_shard_id, num_ds_shards, eval_steps, ++ model.FEATURE_CONVERTER_CLS) ++ if prepare_eval_iter_fn: ++ for k in train_eval_datasets: ++ train_eval_datasets[k] = prepare_eval_iter_fn(train_eval_datasets[k], ++ checkpoint_cfg=checkpoint_cfg, ++ partitioner=partitioner, ++ data_layout=data_layout) ++ ++ + if evaluator is not None: + logging.info('Running inference eval before training.') + _run_inference_eval() +@@ -792,6 +836,18 @@ def train( + # Maybe less if final step < period. + first_run = step_offset // eval_period <= 1 + _run_training_eval(first_run and not run_eval_before_training) ++ if run_dali_eval: ++ ## reset eval dataset ++ train_eval_datasets = train_eval_get_dataset_fn( ++ train_eval_dataset_cfg, ds_shard_id, num_ds_shards, eval_steps, ++ model.FEATURE_CONVERTER_CLS) ++ if prepare_eval_iter_fn: ++ for k in train_eval_datasets: ++ train_eval_datasets[k] = prepare_eval_iter_fn(train_eval_datasets[k], ++ checkpoint_cfg=checkpoint_cfg, ++ partitioner=partitioner, ++ data_layout=data_layout) ++ + + # Inference Evaluation (i.e., with decoding or scoring). + if is_eval_epoch and evaluator is not None: +diff --git a/t5x/trainer.py b/t5x/trainer.py +index 752beb3..5475376 100644 +--- a/t5x/trainer.py ++++ b/t5x/trainer.py +@@ -34,6 +34,7 @@ import clu.metrics + import clu.values + from flax.core import FrozenDict + import jax ++from jax.experimental import multihost_utils + import jax.lax + import jax.numpy as jnp + import jax.random +@@ -596,6 +597,99 @@ class BaseTrainer(abc.ABC): + # TODO(adarob): Return futures. + return {k: v.result() for k, v in eval_summaries.items()} + ++ def eval_dali( ++ self, batch_iters: Mapping[str, ++ Iterator[BatchType]], ++ batch_size: int, ++ shard_id: int, ++ num_shards: int, ++ eval_steps: Optional[int] = None) -> Mapping[str, Array]: ++ """For DALI datasets, runs evaluation loop over the iterator and prints summary.""" ++ ++ def _remove_padding_eval(all_evaluations, all_nonpaddings): ++ """Remove padded examples.""" ++ ++ for k in all_evaluations: ++ all_evaluations[k] = all_evaluations[k][all_nonpaddings] ++ return all_evaluations ++ ++ eval_summaries = {} ++ train_state = self.train_state ++ ++ for iter_name, ds in batch_iters.items(): ++ logging.info("Evaluating: %s.", iter_name) ++ metrics = None ++ # Use a pre-compiled step function, if available. ++ eval_step_fn = self._compiled_eval_steps.get(iter_name, ++ self._partitioned_eval_step) ++ ++ mm = self.eval_metrics_managers[iter_name] ++ nonpaddings = [] ++ metrics = {} ++ last_source = None ++ batches_per_shard = None ++ ++ num_steps = 0 ++ mm.start_duration_timer(block_on=train_state) ++ ++ for batch in ds: ++ num_steps += 1 ++ ++ utils.multihost_assert_equal( ++ jnp.array(num_steps), ++ "Eval step mismatch across hosts. Check for empty dataset shard.") ++ ++ batch_nonpadding = ds.is_nonpadding ++ ++ if jax.process_count() > 1: ++ batch, batch_nonpadding = partitioning.host_local_array_to_global_array( ++ (batch, batch_nonpadding), self._partitioner.mesh, ++ self._partitioner.data_partition_spec) ++ ++ metrics_update, batch_nonpadding = eval_step_fn(train_state, batch, batch_nonpadding) ++ ++ metrics_update, batch_nonpadding = ( ++ multihost_utils.global_array_to_host_local_array( ++ (metrics_update, batch_nonpadding), self._partitioner.mesh, ++ (jax.sharding.PartitionSpec(), jax.sharding.PartitionSpec()) ++ ) ++ ) ++ ++ for k in metrics_update: ++ if k in metrics: ++ metrics[k] = np.concatenate((metrics[k], metrics_update[k])) ++ else: ++ metrics[k] = np.array(metrics_update[k]) ++ ++ if len(nonpaddings) == 0: ++ nonpaddings = batch_nonpadding ++ else: ++ nonpaddings = np.concatenate([nonpaddings, batch_nonpadding], axis=None) ++ ++ if batches_per_shard is None: ++ meta = ds._iterator.wrapped_pipeline.meta ++ batches_per_shard = meta['epoch_size_padded'] / len(batch_nonpadding) ++ ++ utils.multihost_assert_equal( ++ jnp.array(-1), ++ "Eval step mismatch across hosts. Check for empty dataset shard.") ++ ++ ## indicates that we have reached the end of eval ++ if (eval_steps and num_steps >= eval_steps) or (num_steps >= batches_per_shard): ++ break ++ ++ metrics = _remove_padding_eval(metrics, nonpaddings) ++ metrics = {k: jnp.mean(metrics[k]) for k in metrics.keys()} ++ eval_summaries[iter_name] = metrics ++ ++ clu_metrics = {k: clu.metrics.Average.from_model_output(jnp.asarray([metrics[k]])) for k in metrics} ++ eval_summaries[iter_name] = mm.write_metrics_summary( # pytype: disable=wrong-arg-types # jax-ndarray ++ clu_metrics, train_state.step, num_steps) ++ ++ ++ logging.info(f'Completed eval. Metrics: {metrics}') ++ return {k: v.result() for k, v in eval_summaries.items()} ++ + def compile_eval(self, batches: Mapping[str, BatchType]) -> None: + """Pre-compiles eval step (if not yet compiled). + +@@ -633,6 +727,44 @@ class BaseTrainer(abc.ABC): + self.eval_metrics_managers[eval_name].write_scalar( # pytype: disable=wrong-arg-types # jax-ndarray + "timing/compilation_seconds", tock - tick, self.train_state.step) + ++ def compile_eval_dali(self, ++ batches: Mapping[str, BatchType], ++ nonpadding: jnp.ndarray) -> None: ++ """For DALI datasets, pre-compiles eval step (if not yet compiled). ++ ++ Not required. ++ ++ Pre-compiles the evaluation step for each evaluation dataset, reusing cached ++ compilations where possible. In other words, if multiple evaluation datasets ++ have equivalent shapes/dtypes for the batch and initial metrics, ++ recompilation will be avoided. ++ ++ If not called before `eval`, compilation will occur automatically on the ++ first step and JAX's "jit cache" will be used to avoid recompilation for ++ future steps. ++ ++ Args: ++ batches: a mapping from evaluation dataset name to a sample batch. The ++ batch may contain dummy values, but the shapes and dtypes must be ++ correct. ++ nonpadding: a dummy boolean array of padding values. ++ """ ++ for eval_name, batch in batches.items(): ++ tick = time.time() ++ cache_key: BatchSpec = FrozenDict(jax.eval_shape(lambda: batch)) # pylint:disable=cell-var-from-loop ++ if cache_key not in self._compiled_eval_step_cache: ++ if jax.process_count() > 1: ++ batch = partitioning.host_local_array_to_global_array( ++ batch, self._partitioner.mesh, ++ self._partitioner.data_partition_spec) ++ self._compiled_eval_step_cache[cache_key] = self._partitioner.compile( ++ self._partitioned_score_step, self.train_state, batch, nonpadding) ++ self._compiled_eval_steps[eval_name] = self._compiled_eval_step_cache[ ++ cache_key] ++ tock = time.time() ++ self.eval_metrics_managers[eval_name].write_scalar( ++ "timing/compilation_seconds", tock - tick, self.train_state.step) ++ + @property + @abc.abstractmethod + def _partitioned_train_step(self) -> PartitionedTrainCallable: +@@ -828,6 +960,10 @@ def eval_step(model: models.BaseModel, train_state: train_state_lib.TrainState, + # pytype: enable=wrong-arg-types + return metrics + ++def score_step(model: models.BaseModel, train_state: train_state_lib.TrainState, ++ batch: Mapping[str, jnp.ndarray], nonpaddings: jnp.ndarray) -> MetricMapType: ++ metrics = model.get_metrics_per_batch(train_state.params, batch) ++ return metrics, nonpaddings + + def train_with_lr( + train_state: train_state_lib.TrainState, +@@ -940,6 +1076,16 @@ class Trainer(BaseTrainer): + self._partitioner.data_partition_spec), + out_axis_resources=None) + ++ @cached_property ++ def _partitioned_score_step(self) -> PartitionedEvalCallable: ++ return self._partitioner.partition( ++ lambda train_state, batch, nonpaddings: score_step(self._model, train_state, ++ batch, nonpaddings), ++ in_axis_resources=(self._train_state_axes, ++ self._partitioner.data_partition_spec, ++ self._partitioner.data_partition_spec), ++ out_axis_resources=None) ++ + + def _warn_action_not_run(action, task, metric): + logging.warning( +-- +2.25.1 + + +From fbb093fa5851499289804594078b473696c52476 Mon Sep 17 00:00:00 2001 +From: ashors1 +Date: Wed, 1 Nov 2023 11:17:53 -0700 +Subject: [PATCH 2/2] fix bug in rebase + +--- + t5x/train.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/t5x/train.py b/t5x/train.py +index c0263ca..600e3a4 100644 +--- a/t5x/train.py ++++ b/t5x/train.py +@@ -310,7 +310,7 @@ def train( + ) + + if prepare_train_iter_fn: +- train_iter = utils.prepare_train_iter( ++ train_iter = prepare_train_iter_fn( + train_iter, + checkpoint_cfg=checkpoint_cfg, + partitioner=partitioner, +-- +2.25.1 + diff --git a/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch b/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch new file mode 100644 index 000000000..9250d85e0 --- /dev/null +++ b/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch @@ -0,0 +1,29 @@ +From 0678aa3e8b7d5f7365aebb191f2e0794ca95a4b6 Mon Sep 17 00:00:00 2001 +From: ashors1 +Date: Mon, 17 Apr 2023 13:15:01 -0700 +Subject: [PATCH] make strict and fallback_to_scratch args configurable to add + support for partial checkpoint restore + +--- + t5x/train.py | 5 +++++ + 1 file changed, 5 insertions(+) + +diff --git a/t5x/train.py b/t5x/train.py +index 61682ed..77e0860 100644 +--- a/t5x/train.py ++++ b/t5x/train.py +@@ -354,6 +354,11 @@ def train( + checkpoint_cfg.save and checkpoint_cfg.save.save_dataset + ), + state_transformation_fns=state_transforms_for_restore, ++ strict=(checkpoint_cfg.restore.strict ++ if checkpoint_cfg.restore is not None else True ++ ), ++ fallback_to_scratch=(checkpoint_cfg.restore.fallback_to_scratch ++ if checkpoint_cfg.restore is not None else False) + ) + ] + # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set. +-- +2.25.1 + diff --git a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch new file mode 100644 index 000000000..e1c034960 --- /dev/null +++ b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch @@ -0,0 +1,3553 @@ +From 6ac19ca55511e08e8751ab95e312df0399c19c5b Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Mon, 24 Apr 2023 10:18:29 -0700 +Subject: [PATCH 01/16] Added transformer engine support and GPU optimizations + +Co-authored-by: Sahil Jain +Co-authored-by: Terry Kong +Co-authored-by: Yu-Hang Tang +Co-authored-by: Ming Huang +Co-authored-by: Frederic Bastien +Co-authored-by: Sharath Turuvekere Sreenivas +Co-authored-by: Xiaowei Ren +Co-authored-by: Ryan Jeng +Co-authored-by: Reese Wang +--- + README.md | 5 +- + docs/usage/gpu-usage.md | 144 ++++++--- + t5x/contrib/gpu/Dockerfile | 6 +- + t5x/contrib/gpu/README.md | 90 +----- + t5x/contrib/gpu/T5X_TE_README.md | 99 ++++++ + .../scripts_gpu/example_slurm_ft_frompile.sub | 46 +-- + .../example_slurm_pretrain_pile.sub | 38 ++- + .../scripts_gpu/multiprocess_ft_frompile.sh | 82 +++-- + .../scripts_gpu/multiprocess_pretrain_pile.sh | 120 ++++---- + .../gpu/scripts_gpu/singlenode_ft_frompile.sh | 17 +- + .../scripts_gpu/singlenode_pretrain_pile.sh | 17 +- + t5x/contrib/gpu/t5/configs/runs/finetune.gin | 6 +- + .../gpu/t5/configs/runs/finetune_mnli.gin | 6 +- + .../gpu/t5/configs/runs/finetune_squad1.gin | 6 +- + t5x/contrib/gpu/t5/configs/runs/pretrain.gin | 4 +- + t5x/contrib/gpu/t5/network.py | 121 ++++++-- + .../examples/large_mnli2_finetune_adam.gin | 1 + + .../examples/large_squad1_finetune_adam.gin | 1 + + .../examples/small_mnli2_finetune_adam.gin | 1 + + .../examples/small_squad1_finetune_adam.gin | 1 + + .../examples/xl_mnli2_finetune_adam.gin | 1 + + .../examples/xl_squad1_finetune_adam.gin | 1 + + t5x/models.py | 161 ++++++++-- + t5x/partitioning.py | 3 + + t5x/te_helper.py | 284 ++++++++++++++++++ + t5x/train.py | 59 ++++ + t5x/train_state.py | 3 +- + t5x/trainer.py | 45 +-- + 28 files changed, 1045 insertions(+), 323 deletions(-) + create mode 100644 t5x/contrib/gpu/T5X_TE_README.md + create mode 100644 t5x/te_helper.py + +diff --git a/README.md b/README.md +index 916b684..bb0146a 100644 +--- a/README.md ++++ b/README.md +@@ -72,8 +72,11 @@ be read by TensorBoard. + ## GPU Usage + Note: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. Please visit the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository for more details and usage instructions. + +-T5X can be run easily on GPUs either in single-node configurations or multi-node configurations with a SLURM+pyxis cluster. Further instructions at [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. These scripts and associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA with H100 FP8 support and broad GPU performance improvements. ++T5X can be run easily on GPUs either in single-node configurations or multi-node configurations with a SLURM+pyxis cluster. Further instructions at [Rosetta T5X README](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. These scripts and associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA with H100 FP8 support and broad GPU performance improvements. + ++We now have support for: ++- [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) FP8 ++- Improved performance on H100/A100 GPUs + + ## Installation + +diff --git a/docs/usage/gpu-usage.md b/docs/usage/gpu-usage.md +index dedcd88..a9974e1 100644 +--- a/docs/usage/gpu-usage.md ++++ b/docs/usage/gpu-usage.md +@@ -1,4 +1,4 @@ +-# GPU Scripts ++# GPU Scripts and Usage + + # Warning! + An updated version of T5x with optimized GPU performance (18-80% perf gains!) and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x). +@@ -6,12 +6,14 @@ An updated version of T5x with optimized GPU performance (18-80% perf gains!) an + **NVIDIA no longer recommends using this repository and won't be updating it further.** + ----- + +-The [t5x/contrib/gpu](../../t5x/contrib/gpu) directory contains scripts optimized for GPU usage. ++The [t5x/contrib/gpu/scripts_gpu](../../t5x/contrib/gpu/scripts_gpu) directory contains scripts optimized for GPU usage and includes FP8 support via [Transformer Engine](https://github.com/NVIDIA/TransformerEngine). + + Install with `pip install -r pile_requirements.txt` to get all pile dependencies. + + ## Building the container +-The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` ++We provide a fully built and ready-to-use container here: [ghcr.io/nvidia/t5x:te-fp8-reference](ghcr.io/nvidia/t5x:te-fp8-reference) ++If you'd like you build your own, ++The Dockerfile in `t5x/contrib/gpu` will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` + + ## Running interactively + Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example: +@@ -19,7 +21,7 @@ Note: this should only be done with singlenode jobs and/or for downloading the p + `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir` + + ## Downloading The Pile +-Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. ++We use The Pile for our pretraining experiments. If you would like to as well, run `download_the_pile.py` to download it. The download is approximately 1TB. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. + + ## Single Node runs + Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host. +@@ -27,61 +29,127 @@ Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build + ## Multi Node runs + For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput) + +-## Convergence +-For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes. ++## Convergence and performance ++For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2016-2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100-SXM4-80G) and H100-SXM-80G nodes. + +-| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | +-| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- | +-| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | +-| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) | +-| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) | +-| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)| ++| size | GPU | Precision | #GPUs | TP | BS / GPU | Sequences/Sec | Seq/Sec/GPU | Est. Walltime | GPU-days | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | Config | ++| ---- | ------------ | --------- | ----- | ----- | -------- | ------------- | ----------- | ------------- | -------- |------------------ | ------------------ | --------------- | ---- | ++| [T5-v1.1-small](../t5/t5_1_1/small.gin) | A100 80G SXM | bf16 | 8 | 1 | 256 | ~5712 | 714 | 4.2 days | 33 | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | [pile](../t5/t5_1_1/examples/small_pile_pretrain.gin) ++| [T5-v1.1-large](../t5/t5_1_1/large.gin) | A100 80G SXM | bf16 | 64 | 1 | 32 | ~4853 | 75.8 | 4.8 days | 309 | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) ++| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 144 | 1 | 8 | ~3021 | 21.0 | 7.9 days | 1,133 | N/A(perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) ++| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) ++| [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) ++| [T5-v1.1-large](../t5/t5_1_1/large.gin) | **H100 80G SXM** | TE-fp8 | 64 | 1 | 32 | ~10156 | **158.7** | **2.3 days** | **147** | 89.1% | 86.36 / 93.5 | |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) ++| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 144 | 1 | 14 | ~7257 | **50.4** | **3.3 days** | **475** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) ++| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 256 | 1 | 8 | ~9688 | **37.8** | **2.4 days** | **614** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + + Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any). + +-(More perf improvements coming soon!) +- + Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory. + + ## Pretraining run commands + +-### Singlenode +-small: ++### Multinode ++Arguments are set by environment variable as such: + +-`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}` ++`PREC={PRECISION} T5_SIZE={SIZE} BSIZE_PER_GPU={BSIZE} ..... sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {GPUS_PER_NODE}` + +-Finetuning: +-MNLI v2: +-`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}` ++All parameters can be found in the relevant script. + ++### Example Pretraining Commands ++Assumes 8GPU 80GB A100/H100 Nodes. `ENABLE_FP8` uses transformer engine (included in container) and requires H100 + +-### Multinode +-Arguments are as such: ++* Note: To use, FP8 set `ENABLE_FP8` to `1`. This will automatically set `PREC` to `bfloat16` as is required by internals for `FP8` usage. ++#### [T5-v1.1-small](../t5/t5_1_1/small.gin) (60M): ++```sh ++PREC=bfloat16 T5_SIZE=small BSIZE_PER_GPU=256 TRAIN_STEPS=1000000 NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ ++sbatch -N1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub ++``` + +-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` ++#### [T5-v1.1-large](../t5/t5_1_1/large.gin) (770M): ++```sh ++PREC=bfloat16 T5_SIZE=large BSIZE_PER_GPU=32 TRAIN_STEPS=1000000 NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ ++sbatch -N8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub ++``` + +-small: ++#### [T5-v1.1-xl](../t5/t5_1_1/xl.gin) (3B): ++```sh ++PREC=bfloat16 T5_SIZE=large BSIZE_PER_GPU=8 TRAIN_STEPS=1000000 NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ ++sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub ++``` + +-`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1` ++### Example Finetuning Commands ++Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. + +-large: ++#### MNLI v2: ++```sh ++FT_TASK=mnli2 PREC=bfloat16 T5_SIZE={SIZE} BSIZE_PER_GPU={BSIZE} NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ ++sbatch -N{NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub ++``` + +-`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1` ++#### SQuAD v1.1: ++```sh ++FT_TASK=squad1 PREC=bfloat16 T5_SIZE={SIZE} BSIZE_PER_GPU={BSIZE} NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ ++sbatch -N{NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub + +-xl: ++``` + +-`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1` ++## Performance Settings: ++There are 3 major performance settings: `ENABLE_FP8`, `FUSE_QKV` and `TRANSPOSE_BS` (all of which are controllable via env var in the commands above). ++We recommend always enabling `TRANSPOSE_BS` (default), but only using `FUSE_QKV` when using `ENABLE_FP8` for optimal performance. + +-Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. ++On all finetuning runs, we use a Global Batch Size of 256 with bfloat16 precision + FP8. + +-MNLI v2: +- +-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` +- +-SQuAD v1.1 ++WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. + +-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` ++### Singlenode (single process) ++small: + +-On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision. ++```sh ++t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh \ ++ small \ ++ bfloat16 \ ++ 8 \ ++ 256 \ ++ {LOGDIR - create before running} \ ++ {MODEL_DIR} \ ++ {GRADIENT_ACCUMULATION (1 by default)} \ ++ {ENABLE_FP8 (1 by default)} \ ++ {TRANSPOSE_BS (1 by default)} \ ++ {FUSE_QKV (1 by default)} \ ++ {PACK (0 by default)} ++``` ++ ++WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. + +-WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. +\ No newline at end of file ++Finetuning: ++MNLI v2: ++```sh ++t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh \ ++ mnli2 \ ++ small \ ++ bfloat16 \ ++ 8 \ ++ 256 \ ++ {LOGDIR - create before running} \ ++ {MODEL_DIR(to restore pretrained checkpoint from)} \ ++ {GRADIENT_ACCUMULATION (1 by default)} \ ++ {MAKE_FT_DIR (false by default)} ++ {ENABLE_FP8 (1 by default)} \ ++ {TRANSPOSE_BS (1 by default)} \ ++ {FUSE_QKV (1 by default)} \ ++ {PACK (0 by default)} ++``` ++ ++WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. ++# Changelog ++- Added Transformer Engine + FP8 support ++- Added the Transposed Batch-Sequence GPU optimization ++- A100 Perf gains! (BF16) ++ - 80% speedup - T5-small ++ - 23% speedup - T5-large ++ - 18% speedup - T5-xl ++ - 40% speedup - T5-xxl ++- H100 FP8 support, with gains over A100 ++ - 2.08x faster - T5-large (FP8) ++ - 2.24x faster - T5-xl (FP8) +diff --git a/t5x/contrib/gpu/Dockerfile b/t5x/contrib/gpu/Dockerfile +index 4ab560e..5c90435 100644 +--- a/t5x/contrib/gpu/Dockerfile ++++ b/t5x/contrib/gpu/Dockerfile +@@ -1,14 +1,12 @@ +-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:22.08-tf2-py3 ++ARG FROM_IMAGE_NAME=ghcr.io/nvidia/jax-toolbox-internal:5061977725-te + FROM ${FROM_IMAGE_NAME} + +-# Install the latest jax +-RUN pip install jax[cuda]==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +- + # setup directory paths for T5x + ENV TFDS_DATA_DIR=/t5x_home/datasets/ + ENV T5X_DIR=/t5x_home/ + ENV T5X_WORKSPACE_DIR=/t5x_home/workspace + ENV PYTHONPATH=/t5x_home/ ++ + WORKDIR /t5x_home + + # install the requirements for T5x +diff --git a/t5x/contrib/gpu/README.md b/t5x/contrib/gpu/README.md +index 6e7cc57..7208713 100644 +--- a/t5x/contrib/gpu/README.md ++++ b/t5x/contrib/gpu/README.md +@@ -1,90 +1,2 @@ + # GPU Scripts +- +-# Warning! +-An updated version of T5x with optimized GPU performance and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x). +------ +-**NVIDIA no longer recommends using this repository and won't be updating it further.** +------ +- +-The [t5x/contrib/gpu/scripts_gpu](scripts_gpu) directory contains scripts optimized for GPU usage. +- +-To get all dependencies for the Pile dataset, install with the `gpu` extra: +-```bash +-pip install '.[gpu]' +-``` +- +-## Building the container +-The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` +- +-## Running interactively +-Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example: +- +-`t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir` +- +-## Downloading The Pile +-Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. +- +-## Single Node runs +-Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host. +- +-## Multi Node runs +-For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput) +- +-## Convergence +-For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes. +- +-| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | +-| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- | +-| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | +-| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) | +-| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) | +-| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)| +- +-Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any). +- +-(More perf improvements coming soon!) +- +-Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory. +- +-## Pretraining run commands +- +-### Singlenode +-small: +- +-`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}` +- +-Finetuning: +-MNLI v2: +-`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}` +- +- +-### Multinode +-Arguments are as such: +- +-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` +- +-small: +- +-`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1` +- +-large: +- +-`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1` +- +-xl: +- +-`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1` +- +-Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. +- +-MNLI v2: +- +-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` +- +-SQuAD v1.1 +- +-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` +- +-On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision. +- +-WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. ++This folder containers scripts that help run optimized T5x code on GPU with FP8 support. Please refer to [Rosetta T5X README](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x/README.md) for further guides. +diff --git a/t5x/contrib/gpu/T5X_TE_README.md b/t5x/contrib/gpu/T5X_TE_README.md +new file mode 100644 +index 0000000..f182f9b +--- /dev/null ++++ b/t5x/contrib/gpu/T5X_TE_README.md +@@ -0,0 +1,99 @@ ++# T5X with Transformer Engine Summary # ++ ++**Highlight:** ++1. Add `TransformerEngineHelper` to allow users to switch with or without Transformer Engine. ++2. Add the feature of transposing batch_size and sequence to accelerate performance. ++2. Hide FP8 metadata in `flax_mutable`. The flax_mutable is a variable collection that originally is declared by T5X. ++ ++## The *.gin files ## ++They are configurations to set up T5X. The major change is to replace the AdaFactor optimizer with AdamW because of performance concerns. In old XLA, using AdaFactor will generate a lot of D2D copies and slow down the performance. Although the issue was resolved, we used AdamW to verify convergence and performance tests for now. ++ ++## network.py ## ++1. The `TransformerEngineHelper` is a singleton to manage ON/OFF Transformer Engine, to hide the if-else statement inside. The pseudo code is like: ++ ```python ++ class TransformerEngineHelper: ++ @staticmethod: ++ def foo(x): ++ if _IS_TRANSFORMER_ENGINE_INSTALLED and use_te: ++ y = TransformerEngine.foo(x) ++ else: ++ y = T5X.foo(x) ++ return y ++ ``` ++2. The input tensor is BATCH_SEQ_HIDDEN format (i.e., batch_size, sequence, ...) by default. If `cfg.transpose_batch_sequence` is True, transpose input tensor to SEQ_BATCH_HIDDEN format because using SEQ_BATCH_HIDDEN is faster for now. It might not be necessary after integrating cuDNN MHA. And according to `output_format` to decide whether to transpose output tensor or not. It is for easy debugging. ++3. The reason to rename the mask from `encoder_mask`/`decoder_mask` to `attention_mask` is to align the kwargs of TransformerLayer between T5X and Transformer Engine. The original T5X TransformerLayer has a different parameter list than the Transformer Engine. It blocks us from making a functor to switch two of them. The pseudo code is like: ++ ```python ++ if use_te: ++ TransformerLayer = te.TransformerLayer ++ else: ++ TransformerLayer = t5x.TransformerLayer ++ ++ y = TransformerLayer(x, attention_mask=mask) ++ ``` ++4. The `TransformerEngineHelper.get_attn_mask(*_mask)` is used to convert the T5X mask to the format required by Transformer Engine. In T5X, `1` means keep and `0` means drop, but in Transformer Engine, the meaning is reversed. ++ ++## utils.py ## ++1. The `jax.eval_shape` has to be wrapped by `TransformerEngineHelper.eval_shape_guard()` because the `ShardingResource` must be set first. Otherwise, xmap cannot infer the shape of each layer of the model, and an exception will be thrown. ++2. The `flax_mutables` is a variable collection that contains FP8 metadata and sharding information (e.g., named logical axis). It is required by FP8 training and tensor parallelism. ++ ++## trainer.py ## ++1. At the code: `grad_fn = jax.value_and_grad(model.loss_fn, argnums=(0, 3), ...)`, the number `0` refers to 1st argument of loss_fn, and the number `3` refers to 4th argument of loss_fn. The 1st argument is input tensor. The 4th argument is the `flax_mutables` which contains FP8 metadata. In order to get the updated FP8 metadata after 1 training step, we need to ask JAX to differentiate `flax_mutables`. Note that, in fact, FP8 metadata is NOT calculated by differentiation. The FP8 metadata is maintained by the Transformer Engine. It is a trick to get the updated FP8 metadata because we didn't find other interfaces or approaches to get it. ++2. At the code: ++ ```diff ++ - initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else None ++ + initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else {} ++ ``` ++ The `None` should be a T5X bug. It will trigger exceptions if `flax_mutables` needs to be filled into JAX routines. Although T5X declares the `flax_mutables`, it actually doesn't use it. Thus, T5X developers weren't aware of this issue. ++3. The `grad_accum` becomes a list of variable collection because two variables are differentiated. The 1st is model parameters. The 2nd is FP8 metadata. ++4. At the code: ++ ```python ++ grad_accum = (grad_accum[0], ++ TransformerEngineHelper.update_fp8_metas( ++ grad_accum[1], flax_mutables, train_state.step)) ++ ``` ++ It is a workaround due to the T5X (or JAX) bug. We don't know the root-cause yet and don't have time to investigate it. The bug is that T5X always misses 1 time of accumulating gradients. For example, if the accumulation step is 10, T5X should run micro-batch 10 times and accumulate the gradient of each micro-batch but it only accumulates gradient 9 times. If the accumulation step is 1, T5X doesn't update the gradient. Thus, the workaround is to accumulate the gradient 1 time manually. ++ ++## train_state.py ## ++1. Add `flax_mutables_axes`, so xmap can know how to do the sharding for FP8 metadata. ++ ++## train.py ## ++1. Import `TransformerEngineHelper` and initialize it. ++ ++## te_helper.py ## ++1. A new file contains the `TransformerEngineHelper` implementation. Note that it uses Transformer Engine internal API - `FP8Helper.initialize` and `FP8Helper.finalize`. It is a trade off between the number of lines of code changes and the recommended way for enabling FP8 training. The recommended approach is: ++ ```python ++ with te.fp8_autocast(fp8_format, ...): ++ model = Net() ++ variable_collection = model.init(rng, inputs) ++ state = TrainState.create(apply_fn=model.apply, ...) ++ train_epoch(state, dataset) ++ ``` ++ It is equal to: ++ ```python ++ FP8Helper.initialize(fp8_format, ...) # allocate FP8 metadata and setup ++ model = Net() ++ variable_collection = model.init(rng, inputs) ++ state = TrainState.create(apply_fn=model.apply, ...) ++ train_epoch(state, dataset) ++ FP8Helper.finalize() # release FP8 metadata ++ ``` ++ ++## partitioning.py ## ++1. Append the sharding rules needed by Transformer Engine after T5X's rues ++ ++## models.py ## ++1. Add `eval_fn` because a new argument - `flax_mutable` is needed. ++2. Add `predict_batch` because a new argument - `flax_mutable` is needed. ++3. At the code: ++ ```python ++ module.apply( ++ {'params': params, **flax_mutable}, ++ ... ++ ) ++ ``` ++ The module.apply only accepts 1 variable collection, so model parameters and FP8 metadata need to be merged before filled into apply. ++4. The `cache_offset` indicates which dimension is batch_size, for beam-search. Thus, it must be changed if `cfg.transpose_batch_sequence` is True. ++ ++## run_t5x_*.sh ## ++1. They are shell scripts for convenience in running experiments. ++ +diff --git a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub b/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub +index 966ea69..19ec14e 100755 +--- a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub ++++ b/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub +@@ -48,19 +48,28 @@ T5X_WORKSPACE_DIR=/t5x_home/workspace + MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" + + # Add T5x/JAX specific exports +-EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR}" ++EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR},PYTHONPATH=${T5X_DIR}" + #------------------------------------------------------------------------------- + +-# Command line arguments needed by the underlying scripts +-TASK=$1 # mnli2 or squad1, add others with corresponding gin files +-T5_SIZE=$2 # small, base, large, xl, xxl +-PREC="$3" # bfloat16, float32 +-GPUS_PER_NODE=$4 # usually 8 +-BSIZE_PER_GPU=$5 # local batch size/gpu +-MODEL_DIR_LOCAL=$6 # directory to save checkpoints and config dump to +-NUM_MICROBATCHES=$7 # number of gradient accumulation steps +- +-NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) ++FT_TASK=${FT_TASK:-mnli2} ++PREC=${PREC:="bfloat16"} ++T5_SIZE=${T5_SIZE:="large"} ++BSIZE_PER_GPU=${BSIZE_PER_GPU:=32} ++ENC_SL=${ENC_SL:=512} ++DEC_SL=${DEC_SL:=128} ++NUM_MICROBATCHES=${NUM_MICROBATCHES:=1} ++ENABLE_FP8=${ENABLE_FP8:=1} ++TP_SIZE=${TP_SIZE:=1} ++TRANSPOSE_BS=${TRANSPOSE_BS:=1} ++MODEL_DIR=${MODEL_DIR:=model_dir} ++FUSE_QKV=${FUSE_QKV:=1} ++PACK=${PACK:=0} ++ ++export GPUS_PER_NODE=${1:-8} ++export BASE_SCRIPT=${2:-"${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh"} ++export WITH_MP=1 ++ ++NUM_GPUS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) + + # redirect both stdout and stderr in the same file for ease of analysis + OUTDIR="outputs/multinode/${TASK}_t5_${T5_SIZE}-prec_${PREC}-nodes_${SLURM_JOB_NUM_NODES}-gpus_${NUM_GPUS}-bs_${BSIZE_PER_GPU}-sl_${SL}" +@@ -73,16 +82,17 @@ LOGDIR="${T5X_WORKSPACE_DIR}/${OUTDIR}" + # You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. + # && bash <>/bind.sh --cpu=exclusive --ib=single -- \ + read -r -d '' cmd <>/bind.sh --cpu=exclusive --ib=single -- \ + read -r -d '' cmd < \ +- ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log ++ --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} \ ++ --gin.train/utils.DatasetConfig.pack=${PACK} \ ++ --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ ++ --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ ++ --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ ++ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ ++ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ ++ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ ++ &> \ ++ ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log +diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +index c82322a..def1a1a 100755 +--- a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +@@ -29,6 +29,11 @@ LOG_DIR=$5 # Output log directory + MODEL_DIR_LOCAL=${6:-"model_dir"} + MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL} + NUM_MICROBATCHES=${7:-0} ++ENABLE_FP8=${8:-1} ++[[ $ENABLE_FP8 -eq 1 ]] && PREC='bfloat16' # Required for t5x te fp8 to work ++TRANSPOSE_BS=${9:-1} ++FUSE_QKV=${10:-1} ++PACK=${11:-0} + + echo $MODEL_DIR + +@@ -49,5 +54,13 @@ python3 -u ${T5X_DIR}/t5x/train.py \ + --gin.train/utils.DatasetConfig.batch_size=${BSIZE} \ + --gin.trainer.Trainer.num_microbatches=${NUM_MICROBATCHES} \ + --gin.train_eval/utils.DatasetConfig.batch_size=${BSIZE} \ +- --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} &> \ +- ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log ++ --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} \ ++ --gin.train/utils.DatasetConfig.pack=${PACK} \ ++ --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ ++ --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ ++ --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ ++ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ ++ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ ++ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ ++ &> \ ++ ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log +diff --git a/t5x/contrib/gpu/t5/configs/runs/finetune.gin b/t5x/contrib/gpu/t5/configs/runs/finetune.gin +index a76d809..0d68c29 100644 +--- a/t5x/contrib/gpu/t5/configs/runs/finetune.gin ++++ b/t5x/contrib/gpu/t5/configs/runs/finetune.gin +@@ -41,6 +41,7 @@ TASK_FEATURE_LENGTHS = %gin.REQUIRED + MIXTURE_OR_TASK_MODULE = %gin.REQUIRED + TRAIN_STEPS = %gin.REQUIRED + INITIAL_CHECKPOINT_PATH = %gin.REQUIRED ++RESET_STATE_AFTER = None # a flag to reset optimizer and fp8 states (if exist) after a set number of sets (i.e. after pretraining) + + # Commonly overridden + DROPOUT_RATE = 0.1 +@@ -81,6 +82,7 @@ train_script.train: + use_hardware_rng = %USE_HARDWARE_RNG + summarize_config_fn = @gin_utils.summarize_gin_config + inference_evaluator_cls = @seqio.Evaluator ++ reset_state_after = %RESET_STATE_AFTER + + partitioning.PjitPartitioner: + num_partitions = 1 +@@ -103,7 +105,7 @@ train/utils.DatasetConfig: + shuffle = True + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + train_eval/utils.DatasetConfig: +@@ -114,7 +116,7 @@ train_eval/utils.DatasetConfig: + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + infer_eval/utils.DatasetConfig: +diff --git a/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin b/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin +index 7cebbb6..da44827 100644 +--- a/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin ++++ b/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin +@@ -40,6 +40,7 @@ MIXTURE_OR_TASK_NAME = %gin.REQUIRED + TASK_FEATURE_LENGTHS = %gin.REQUIRED + MIXTURE_OR_TASK_MODULE = %gin.REQUIRED + TRAIN_STEPS = %gin.REQUIRED ++RESET_STATE_AFTER = None # a flag to reset optimizer and fp8 states (if exist) after a set number of sets (i.e. after pretraining) + INITIAL_CHECKPOINT_PATH = %gin.REQUIRED + + # Commonly overridden +@@ -80,6 +81,7 @@ train_script.train: + use_hardware_rng = %USE_HARDWARE_RNG + summarize_config_fn = @gin_utils.summarize_gin_config + inference_evaluator_cls = @seqio.Evaluator ++ reset_state_after = %RESET_STATE_AFTER + + partitioning.PjitPartitioner: + num_partitions = 1 +@@ -102,7 +104,7 @@ train/utils.DatasetConfig: + shuffle = True + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + train_eval/utils.DatasetConfig: +@@ -113,7 +115,7 @@ train_eval/utils.DatasetConfig: + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + infer_eval/utils.DatasetConfig: +diff --git a/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin b/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin +index 4ea952c..b1d8e7e 100644 +--- a/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin ++++ b/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin +@@ -40,6 +40,7 @@ MIXTURE_OR_TASK_NAME = %gin.REQUIRED + TASK_FEATURE_LENGTHS = %gin.REQUIRED + MIXTURE_OR_TASK_MODULE = %gin.REQUIRED + TRAIN_STEPS = %gin.REQUIRED ++RESET_STATE_AFTER = None # a flag to reset optimizer and fp8 states (if exist) after a set number of sets (i.e. after pretraining) + INITIAL_CHECKPOINT_PATH = %gin.REQUIRED + + # Commonly overridden +@@ -80,6 +81,7 @@ train_script.train: + use_hardware_rng = %USE_HARDWARE_RNG + summarize_config_fn = @gin_utils.summarize_gin_config + inference_evaluator_cls = @seqio.Evaluator ++ reset_state_after = %RESET_STATE_AFTER + + partitioning.PjitPartitioner: + num_partitions = 1 +@@ -102,7 +104,7 @@ train/utils.DatasetConfig: + shuffle = True + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + train_eval/utils.DatasetConfig: +@@ -113,7 +115,7 @@ train_eval/utils.DatasetConfig: + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + infer_eval/utils.DatasetConfig: +diff --git a/t5x/contrib/gpu/t5/configs/runs/pretrain.gin b/t5x/contrib/gpu/t5/configs/runs/pretrain.gin +index de12864..e9807f9 100644 +--- a/t5x/contrib/gpu/t5/configs/runs/pretrain.gin ++++ b/t5x/contrib/gpu/t5/configs/runs/pretrain.gin +@@ -73,7 +73,7 @@ train/utils.DatasetConfig: + shuffle = %SHUFFLE_TRAIN_EXAMPLES + seed = None # use a new seed each run/restart + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + train_eval/utils.DatasetConfig: +@@ -84,7 +84,7 @@ train_eval/utils.DatasetConfig: + shuffle = False + seed = 42 + use_cached = %USE_CACHED_TASKS +- pack = True ++ pack = False + module = %MIXTURE_OR_TASK_MODULE + + utils.CheckpointConfig: +diff --git a/t5x/contrib/gpu/t5/network.py b/t5x/contrib/gpu/t5/network.py +index dd61078..2be008c 100644 +--- a/t5x/contrib/gpu/t5/network.py ++++ b/t5x/contrib/gpu/t5/network.py +@@ -13,14 +13,18 @@ + # limitations under the License. + + """T5.1.1 Transformer model.""" +- + from typing import Any, Sequence ++from enum import Enum + + from flax import linen as nn + from flax import struct + import jax.numpy as jnp + from t5x.contrib.gpu.t5 import layers ++from t5x.te_helper import TransformerEngineHelper + ++class SeqDataFormat(Enum): ++ BATCH_SEQ_HIDDEN = 'bsh' ++ SEQ_BATCH_HIDDEN = 'sbh' + + @struct.dataclass + class T5Config: +@@ -43,6 +47,10 @@ class T5Config: + float32_attention_logits: bool = False + # Whether to scale attention logits by sqrt(d_k). Default to False for adafactor + scale_attn_logits: bool = False ++ # Whether to transpose batch and sequence to avoid explicit transposes in MHA ++ transpose_batch_sequence: bool = False ++ # Whether to fuse the QKV proj in MHA ++ fuse_qkv_params: bool = False + + + class EncoderLayer(nn.Module): +@@ -51,12 +59,13 @@ class EncoderLayer(nn.Module): + relative_embedding: nn.Module + + @nn.compact +- def __call__(self, inputs, encoder_mask=None, deterministic=False): ++ def __call__(self, inputs, attention_mask=None, deterministic=False): + cfg = self.config + + # Relative position embedding as attention biases. +- encoder_bias = self.relative_embedding(inputs.shape[-2], inputs.shape[-2], +- True) ++ sequence_dim = 0 if cfg.transpose_batch_sequence else 1 ++ encoder_bias = self.relative_embedding( ++ inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) + + # Attention block. + assert inputs.ndim == 3 +@@ -72,7 +81,7 @@ class EncoderLayer(nn.Module): + float32_logits=cfg.float32_attention_logits, + name='attention', + scale_attn_logits=cfg.scale_attn_logits)( +- x, x, encoder_mask, encoder_bias, deterministic=deterministic) ++ x, x, attention_mask, encoder_bias, deterministic=deterministic) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) +@@ -105,7 +114,7 @@ class DecoderLayer(nn.Module): + def __call__(self, + inputs, + encoded, +- decoder_mask=None, ++ attention_mask=None, + encoder_decoder_mask=None, + deterministic=False, + decode=False, +@@ -113,7 +122,8 @@ class DecoderLayer(nn.Module): + cfg = self.config + + # Relative position embedding as attention biases. +- l = max_decode_length if decode and max_decode_length else inputs.shape[-2] ++ sequence_dim = 0 if cfg.transpose_batch_sequence else 1 ++ l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim] + decoder_bias = self.relative_embedding(l, l, False) + + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] +@@ -132,7 +142,7 @@ class DecoderLayer(nn.Module): + scale_attn_logits=cfg.scale_attn_logits)( + x, + x, +- decoder_mask, ++ attention_mask, + decoder_bias, + deterministic=deterministic, + decode=decode) +@@ -185,8 +195,11 @@ class Encoder(nn.Module): + def __call__(self, + encoder_input_tokens, + encoder_mask=None, +- deterministic=False): +- cfg = self.config ++ deterministic=False, ++ output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): ++ ++ cfg = TransformerEngineHelper.get_t5x_config(self.config) ++ + assert encoder_input_tokens.ndim == 2 # [batch, length] + rel_emb = layers.RelativePositionBiases( + num_buckets=32, +@@ -199,6 +212,9 @@ class Encoder(nn.Module): + + # [batch, length] -> [batch, length, emb_dim] + x = self.shared_embedding(encoder_input_tokens.astype('int32')) ++ if cfg.transpose_batch_sequence: ++ # [batch, length, emb_dim] -> [length, batch, emb_dim] ++ x = x.transpose((1, 0, 2)) + x = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic) +@@ -206,12 +222,20 @@ class Encoder(nn.Module): + + for lyr in range(cfg.num_encoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] +- x = EncoderLayer( ++ encoder_lyr = TransformerEngineHelper.get_encoder_layer( + config=cfg, relative_embedding=rel_emb, +- name=f'layers_{lyr}')(x, encoder_mask, deterministic) ++ name=f'layers_{lyr}', original_cls=EncoderLayer) ++ x = encoder_lyr(x, attention_mask=encoder_mask, deterministic=deterministic) + + x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) +- return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) ++ x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) ++ ++ if (cfg.transpose_batch_sequence and output_format is SeqDataFormat.BATCH_SEQ_HIDDEN) or \ ++ (not cfg.transpose_batch_sequence and output_format is SeqDataFormat.SEQ_BATCH_HIDDEN): ++ x = x.transpose((1, 0, 2)) ++ ++ return x ++ + + + class Decoder(nn.Module): +@@ -228,8 +252,16 @@ class Decoder(nn.Module): + encoder_decoder_mask=None, + deterministic=False, + decode=False, +- max_decode_length=None): +- cfg = self.config ++ max_decode_length=None, ++ encoded_format=SeqDataFormat.BATCH_SEQ_HIDDEN, ++ output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): ++ cfg = TransformerEngineHelper.get_t5x_config(self.config) ++ ++ if (cfg.transpose_batch_sequence and encoded_format is SeqDataFormat.BATCH_SEQ_HIDDEN) or \ ++ (not cfg.transpose_batch_sequence and encoded_format is SeqDataFormat.SEQ_BATCH_HIDDEN): ++ encoded = encoded.transpose((1, 0, 2)) ++ ++ + assert decoder_input_tokens.ndim == 2 # [batch, len] + rel_emb = layers.RelativePositionBiases( + num_buckets=32, +@@ -242,6 +274,10 @@ class Decoder(nn.Module): + + # [batch, length] -> [batch, length, emb_dim] + y = self.shared_embedding(decoder_input_tokens.astype('int32')) ++ if cfg.transpose_batch_sequence: ++ # [batch, length, emb_dim] -> [length, batch, emb_dim] ++ y = y.transpose((1, 0, 2)) ++ + y = nn.Dropout( + rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic) +@@ -249,15 +285,16 @@ class Decoder(nn.Module): + + for lyr in range(cfg.num_decoder_layers): + # [batch, length, emb_dim] -> [batch, length, emb_dim] +- y = DecoderLayer( +- config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}')( +- y, +- encoded, +- decoder_mask=decoder_mask, +- encoder_decoder_mask=encoder_decoder_mask, +- deterministic=deterministic, +- decode=decode, +- max_decode_length=max_decode_length) ++ decoder_lyr = TransformerEngineHelper.get_decoder_layer( ++ config=cfg, relative_embedding=rel_emb, ++ name=f'layers_{lyr}', original_cls=DecoderLayer) ++ y = decoder_lyr(y, ++ encoded, ++ attention_mask=decoder_mask, ++ encoder_decoder_mask=encoder_decoder_mask, ++ deterministic=deterministic, ++ decode=decode, ++ max_decode_length=max_decode_length) + + y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) + y = nn.Dropout( +@@ -277,6 +314,11 @@ class Decoder(nn.Module): + kernel_axes=('embed', 'vocab'), + name='logits_dense')( + y) ++ ++ if (cfg.transpose_batch_sequence and output_format is SeqDataFormat.BATCH_SEQ_HIDDEN) or \ ++ (not cfg.transpose_batch_sequence and output_format is SeqDataFormat.SEQ_BATCH_HIDDEN): ++ # [length, batch, vocab_size] -> [batch, length, vocab_size] ++ logits = logits.transpose((1, 0, 2)) + return logits + + +@@ -285,7 +327,7 @@ class Transformer(nn.Module): + config: T5Config + + def setup(self): +- cfg = self.config ++ cfg = TransformerEngineHelper.get_t5x_config(self.config) + self.shared_embedding = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, +@@ -301,7 +343,8 @@ class Transformer(nn.Module): + def encode(self, + encoder_input_tokens, + encoder_segment_ids=None, +- enable_dropout=True): ++ enable_dropout=True, ++ output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): + """Applies Transformer encoder-branch on the inputs.""" + cfg = self.config + assert encoder_input_tokens.ndim == 2 # (batch, len) +@@ -319,8 +362,11 @@ class Transformer(nn.Module): + jnp.equal, + dtype=cfg.dtype)) + ++ encoder_mask = TransformerEngineHelper.get_attn_mask(encoder_mask) ++ + return self.encoder( +- encoder_input_tokens, encoder_mask, deterministic=not enable_dropout) ++ encoder_input_tokens, encoder_mask, deterministic=not enable_dropout, ++ output_format=output_format) + + def decode( + self, +@@ -333,7 +379,9 @@ class Transformer(nn.Module): + decoder_positions=None, + enable_dropout=True, + decode=False, +- max_decode_length=None): ++ max_decode_length=None, ++ encoded_format=SeqDataFormat.BATCH_SEQ_HIDDEN, ++ output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): + """Applies Transformer decoder-branch on encoded-input and target.""" + cfg = self.config + +@@ -369,6 +417,9 @@ class Transformer(nn.Module): + jnp.equal, + dtype=cfg.dtype)) + ++ decoder_mask = decoder_mask if decode else TransformerEngineHelper.get_attn_mask(decoder_mask) ++ encoder_decoder_mask = TransformerEngineHelper.get_attn_mask(encoder_decoder_mask) ++ + logits = self.decoder( + encoded, + decoder_input_tokens=decoder_input_tokens, +@@ -377,7 +428,9 @@ class Transformer(nn.Module): + encoder_decoder_mask=encoder_decoder_mask, + deterministic=not enable_dropout, + decode=decode, +- max_decode_length=max_decode_length) ++ max_decode_length=max_decode_length, ++ encoded_format=encoded_format, ++ output_format=output_format) + return logits + + def __call__(self, +@@ -412,10 +465,15 @@ class Transformer(nn.Module): + Returns: + logits array from full transformer. + """ ++ cfg = TransformerEngineHelper.get_t5x_config(self.config) ++ encoded_format = SeqDataFormat.BATCH_SEQ_HIDDEN ++ if cfg.transpose_batch_sequence: ++ encoded_format = SeqDataFormat.SEQ_BATCH_HIDDEN + encoded = self.encode( + encoder_input_tokens, + encoder_segment_ids=encoder_segment_ids, +- enable_dropout=enable_dropout) ++ enable_dropout=enable_dropout, ++ output_format=encoded_format) + + return self.decode( + encoded, +@@ -426,4 +484,5 @@ class Transformer(nn.Module): + decoder_segment_ids=decoder_segment_ids, + decoder_positions=decoder_positions, + enable_dropout=enable_dropout, +- decode=decode) ++ decode=decode, ++ encoded_format=encoded_format) +diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin +index 12e2ee6..901bd03 100644 +--- a/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin ++++ b/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin +@@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" + MIXTURE_OR_TASK_NAME = "glue_mnli_v2" + TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} + TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. ++RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist + DROPOUT_RATE = 0.1 + INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000" + # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin +index 87b896f..8f49a6f 100644 +--- a/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin ++++ b/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin +@@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" + MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" + TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} + TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. ++RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist + DROPOUT_RATE = 0.1 + INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000" + # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin +index 7f600f1..5391f0e 100644 +--- a/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin ++++ b/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin +@@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" + MIXTURE_OR_TASK_NAME = "glue_mnli_v2" + TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} + TRAIN_STEPS = 1_015_001 # 1000000 pre-trained steps + 15000 fine-tuning steps. ++RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist + DROPOUT_RATE = 0.1 + INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" + # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin +index ba5af03..c2b2797 100644 +--- a/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin ++++ b/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin +@@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" + MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" + TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} + TRAIN_STEPS = 1_015_001 # 1000000 pre-trained steps + 15000 fine-tuning steps. ++RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist + DROPOUT_RATE = 0.1 + INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" + # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin +index 7a58562..75e0fe5 100644 +--- a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin ++++ b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin +@@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" + MIXTURE_OR_TASK_NAME = "glue_mnli_v2" + TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} + TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. ++RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist + DROPOUT_RATE = 0.1 + INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000" + # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin +index 6a4c7b2..a1f50db 100644 +--- a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin ++++ b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin +@@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" + MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" + TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} + TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. ++RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist + DROPOUT_RATE = 0.1 + INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000" + # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained +diff --git a/t5x/models.py b/t5x/models.py +index f31d39f..fcd9fd5 100644 +--- a/t5x/models.py ++++ b/t5x/models.py +@@ -43,6 +43,8 @@ from t5x import optimizers + import tensorflow as tf + import typing_extensions + ++from t5x.te_helper import TransformerEngineHelper ++ + # Remove _ShardedDeviceArray when users of t5x have their types updated + _ShardedDeviceArray = Any + Array = Union[np.ndarray, jnp.ndarray, _ShardedDeviceArray, tf.Tensor] +@@ -135,6 +137,7 @@ class BaseModel(abc.ABC): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.Array], ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Computes loss and metrics. + +@@ -155,6 +158,7 @@ class BaseModel(abc.ABC): + self, + params: PyTree, + batch: Mapping[str, jnp.ndarray], ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Computes loss and metrics during the evaluation. + +@@ -172,6 +176,7 @@ class BaseModel(abc.ABC): + params=params, + batch=batch, + dropout_rng=None, ++ flax_mutables=flax_mutables, + ) + + def predict_batch( +@@ -179,6 +184,7 @@ class BaseModel(abc.ABC): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.Array] = None, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> jnp.ndarray: + """Predicts a batch of outputs from the model. + +@@ -190,7 +196,7 @@ class BaseModel(abc.ABC): + Returns: + The model predictions. + """ +- return self.predict_batch_with_aux(params=params, batch=batch, rng=rng)[0] ++ return self.predict_batch_with_aux(params=params, batch=batch, rng=rng, flax_mutables=flax_mutables)[0] + + @abc.abstractmethod + def predict_batch_with_aux( +@@ -198,6 +204,7 @@ class BaseModel(abc.ABC): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.Array] = None, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Predict a batch from the modelwith auxiliary outputs. + +@@ -218,6 +225,7 @@ class BaseModel(abc.ABC): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> jnp.ndarray: + """Computes scores for batch.""" + pass +@@ -281,6 +289,7 @@ class BaseTransformerModel(BaseModel): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.Array] = None, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> jnp.ndarray: + """Computes logits via a forward pass of the model.""" + pass +@@ -290,9 +299,11 @@ class BaseTransformerModel(BaseModel): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.Array], ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Loss function used for training with a cross-entropy loss.""" +- logits = self._compute_logits(params, batch, dropout_rng) ++ logits = self._compute_logits(params, batch, dropout_rng, ++ flax_mutables=flax_mutables) + + loss_normalizing_factor: Optional[ + Union[float, int, str, losses.SpecialLossNormalizingFactor] +@@ -336,6 +347,31 @@ class BaseTransformerModel(BaseModel): + ) + return loss, metrics + ++ def eval_fn( ++ self, ++ params: PyTreeDef, ++ batch: Mapping[str, jnp.ndarray], ++ flax_mutables: Optional[PyTreeDef] = None, ++ ) -> Tuple[jnp.ndarray, MetricsMap]: ++ """Computes loss and metrics during the evaluation. ++ ++ Args: ++ params: model parameters. ++ batch: a batch of inputs. ++ ++ Returns: ++ loss: the loss computed for the given inputs and parameters. ++ aux: ++ weight_sum: sum of the per-token weights applied to the loss. ++ metrics: a mapping of metrics computed for this batch. ++ """ ++ return self.loss_fn( ++ params=params, ++ batch=batch, ++ dropout_rng=None, ++ flax_mutables=flax_mutables, ++ ) ++ + def _compute_metrics( + self, + logits: jnp.ndarray, +@@ -461,15 +497,18 @@ class EncoderDecoderModel(BaseTransformerModel): + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.Array] = None, + mutable: flax_scope.CollectionFilter = False, +- other_variables: Optional[PyTree] = None, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes logits via a forward pass of `self.module_cls`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': dropout_rng} if dropout_rng is not None else None +- if other_variables is None: +- other_variables = {} ++ if flax_mutables is None: ++ flax_mutables = {} + return self.module.apply( +- {'params': params, **other_variables}, ++ { ++ 'params': params, ++ **flax_mutables, ++ }, + batch['encoder_input_tokens'], + batch['decoder_input_tokens'], + batch['decoder_target_tokens'], +@@ -490,17 +529,25 @@ class EncoderDecoderModel(BaseTransformerModel): + encoded_inputs: jnp.ndarray, + raw_inputs: jnp.ndarray, + max_decode_length: int, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Token slice to logits from decoder model.""" + flat_ids = decoding_state.cur_token + flat_cache = decoding_state.cache + ++ if flax_mutables is None: ++ flax_mutables = {} ++ + # flat_ids: [batch * beam, seq_len=1] + # cache is expanded inside beam_search to become flat_cache + # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len] + # flat_logits: [batch * beam, seq_len=1, vocab] + flat_logits, new_vars = self.module.apply( +- {'params': params, 'cache': flat_cache}, ++ { ++ 'params': params, ++ 'cache': flat_cache, ++ **flax_mutables, ++ }, + encoded_inputs, + raw_inputs, # only needed for encoder padding mask + flat_ids, +@@ -523,6 +570,7 @@ class EncoderDecoderModel(BaseTransformerModel): + encoder_input_tokens: jnp.ndarray, + decoder_input_tokens: jnp.ndarray, + prefill_decoder_prompt: bool = False, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[PyTree, Optional[jnp.ndarray]]: + """Initialize the key/value cache, with optional prompt. + +@@ -541,8 +589,14 @@ class EncoderDecoderModel(BaseTransformerModel): + initial_index: The index of the next position following prefill or None if + `prefill_decoder_prompt` is False. + """ ++ if flax_mutables is None: ++ flax_mutables = {} ++ del encoded_inputs + _, initial_variables = self.module.apply( +- {'params': params}, ++ { ++ 'params': params, ++ **flax_mutables, ++ }, + encoder_input_tokens=jnp.ones_like(encoder_input_tokens), + decoder_input_tokens=jnp.ones_like(decoder_input_tokens), + decoder_target_tokens=jnp.ones_like(decoder_input_tokens), +@@ -583,6 +637,24 @@ class EncoderDecoderModel(BaseTransformerModel): + + return cache, inputs_lengths + ++ def predict_batch(self, ++ params: PyTreeDef, ++ batch: Mapping[str, jnp.ndarray], ++ rng: Optional[jax.random.KeyArray] = None, ++ flax_mutables: Optional[PyTreeDef] = None) -> jnp.ndarray: ++ """Predicts a batch of outputs from the model. ++ ++ Args: ++ params: model parameters. ++ batch: a batch of inputs. ++ rng: an optional RNG to use during prediction (e.g., for decoding). ++ ++ Returns: ++ The model predictions. ++ """ ++ return self.predict_batch_with_aux(params=params, batch=batch, ++ rng=rng, flax_mutables=flax_mutables)[0] ++ + def predict_batch_with_aux( + self, + params: PyTree, +@@ -592,6 +664,7 @@ class EncoderDecoderModel(BaseTransformerModel): + return_all_decodes: bool = None, + num_decodes: int = None, # pytype:disable=annotation-type-mismatch + prompt_with_targets: bool = False, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Predict with fast decoding beam search on a batch. + +@@ -648,11 +721,30 @@ class EncoderDecoderModel(BaseTransformerModel): + return_all_decodes = self._default_decoder_params.return_all_decodes + if num_decodes is None: + num_decodes = self._default_decoder_params.num_decodes ++ if flax_mutables is None: ++ flax_mutables = {} + + # [batch, input_len] + encoder_input_tokens = batch['encoder_input_tokens'] + decoder_input_tokens = batch['decoder_input_tokens'] + ++ # Prepare transformer fast-decoder call for beam search: for beam search, we ++ # need to set up our decoder model to handle a batch size equal to ++ # batch_size * num_decodes, where each batch item's data is expanded ++ # in-place rather than tiled. ++ # i.e. if we denote each batch element subtensor as el[n]: ++ # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] ++ # [batch * num_decodes, input_len, emb_dim] ++ encoded_inputs = decoding.flat_batch_beam_expand( ++ self.module.apply( ++ {'params': params, **flax_mutables}, ++ encoder_input_tokens, ++ enable_dropout=False, ++ method=self.module.encode, ++ ), ++ num_decodes, ++ ) ++ + # `decoder_prompt_inputs` is initialized from the batch's + # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop + # after the prompt by matching to `output_vocabulary.eos_id`. +@@ -693,6 +785,7 @@ class EncoderDecoderModel(BaseTransformerModel): + encoder_input_tokens=encoder_input_tokens, + decoder_input_tokens=decoder_prompt_inputs, + prefill_decoder_prompt=prefill_decoder_prompt, ++ flax_mutables=flax_mutables, + ) + + # Prepare transformer fast-decoder call for beam search: for beam search, we +@@ -714,6 +807,7 @@ class EncoderDecoderModel(BaseTransformerModel): + encoder_input_tokens, num_decodes + ), + max_decode_length=decoder_input_tokens.shape[1], ++ flax_mutables=flax_mutables, + ) + + if decoder_params is None: +@@ -739,6 +833,7 @@ class EncoderDecoderModel(BaseTransformerModel): + # decodes: [batch, num_decodes, max_decode_len + 1] + # scores: [batch, num_decodes] + scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers ++ cfg = TransformerEngineHelper.get_t5x_config(self.module.config) + + if 'eos_id' not in decoder_params: + decoder_params['eos_id'] = self.output_vocabulary.eos_id +@@ -747,8 +842,8 @@ class EncoderDecoderModel(BaseTransformerModel): + cache=cache, + tokens_to_logits=tokens_ids_to_logits, + num_decodes=num_decodes, +- cache_offset=1 if scanned else 0, +- **decoder_params, ++ cache_offset=1 if (scanned or cfg.transpose_batch_sequence) else 0, ++ **decoder_params + ) + + # Beam search returns [n_batch, n_beam, n_length] with beam dimension sorted +@@ -764,6 +859,7 @@ class EncoderDecoderModel(BaseTransformerModel): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]: + """Compute log likelihood score on a batch.""" + weights = batch['decoder_loss_weights'] +@@ -771,7 +867,8 @@ class EncoderDecoderModel(BaseTransformerModel): + + if return_intermediates: + logits, modified_variables = self._compute_logits( +- params=params, batch=batch, mutable=['intermediates'] ++ params=params, batch=batch, mutable=['intermediates'], ++ flax_mutables=flax_mutables, + ) + + # Inside self.module, we called nn.Module.sow to track various +@@ -789,7 +886,7 @@ class EncoderDecoderModel(BaseTransformerModel): + # `intermediates` should be tuples tracking all instantiations of a value. + # These values each have just one instantiation, hence singletons. + else: +- logits = self._compute_logits(params, batch) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray ++ logits = self._compute_logits(params, batch, flax_mutables=flax_mutables) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray + + # Purposefully don't use config.z_loss because that term is for training + # stability and shouldn't affect our reported scores. +@@ -900,16 +997,19 @@ class DecoderOnlyModel(BaseTransformerModel): + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.Array] = None, + mutable: flax_scope.CollectionFilter = False, +- other_variables: Optional[PyTree] = None, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> jnp.ndarray: + """Computes logits via a forward pass of `self.module`.""" + rngs = {'dropout': dropout_rng} if dropout_rng is not None else None + decoder_causal_attention = self._get_decoder_causal_attention(batch) +- if other_variables is None: +- other_variables = {} ++ if flax_mutables is None: ++ flax_mutables = {} + + return self.module.apply( +- {'params': params, **other_variables}, ++ { ++ 'params': params, ++ **flax_mutables, ++ }, + batch['decoder_input_tokens'], + batch['decoder_target_tokens'], + decoder_segment_ids=batch.get('decoder_segment_ids', None), +@@ -926,6 +1026,7 @@ class DecoderOnlyModel(BaseTransformerModel): + decoding_state: decoding.DecodingState, + params: PyTree, + max_decode_length: int, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Token slice to logits from decoder model.""" + flat_ids = decoding_state.cur_token +@@ -936,7 +1037,11 @@ class DecoderOnlyModel(BaseTransformerModel): + # flat_cache['cache_index']: [batch] + # flat_logits: [batch, seq_len=1, vocab] + flat_logits, new_vars = self.module.apply( +- {'params': params, 'cache': flat_cache}, ++ { ++ 'params': params, ++ 'cache': flat_cache, ++ **flax_mutables, ++ }, + flat_ids, + flat_ids, + enable_dropout=False, +@@ -954,6 +1059,7 @@ class DecoderOnlyModel(BaseTransformerModel): + params: PyTree, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> jnp.ndarray: + """Compute log likelihood score on a batch.""" + +@@ -966,6 +1072,7 @@ class DecoderOnlyModel(BaseTransformerModel): + batch=batch, + dropout_rng=None, + mutable=['intermediates'], ++ flax_mutables=flax_mutables, + ) + + # Inside self.module, we called nn.Module.sow to track various +@@ -984,7 +1091,7 @@ class DecoderOnlyModel(BaseTransformerModel): + # These values each have just one instantiation, hence singletons. + else: + logits = self._compute_logits( +- params=params, batch=batch, dropout_rng=None ++ params=params, batch=batch, dropout_rng=None, flax_mutables=flax_mutables, + ) + + token_scores = ( +@@ -1012,6 +1119,7 @@ class DecoderOnlyModel(BaseTransformerModel): + params: PyTree, + inputs: jnp.ndarray, + causal_attention_mask: jnp.ndarray, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[PyTree, jnp.ndarray]: + """Compute the key/value cache on the input prompt. + +@@ -1025,12 +1133,17 @@ class DecoderOnlyModel(BaseTransformerModel): + cache: The prefilled cache. + initial_index: The index of the next position following prefill. + """ ++ if flax_mutables is None: ++ flax_mutables = {} + # The lengths of the inputs match the number of non-padding positions, + # excluding the initial BOS. + inputs_lengths = jnp.sum(inputs[:, 1:] != 0, axis=-1) + + _, initial_variables = self.module.apply( +- {'params': params}, ++ { ++ 'params': params, ++ **flax_mutables, ++ }, + jnp.ones_like(inputs), + jnp.ones_like(inputs), + enable_dropout=False, +@@ -1068,7 +1181,11 @@ class DecoderOnlyModel(BaseTransformerModel): + ) + + _, variables_with_cache = self.module.apply( +- {'params': params, 'cache': cache}, ++ { ++ 'params': params, ++ 'cache': cache, ++ **flax_mutables, ++ }, + decoder_input_tokens=inputs, + # Use the `decoder_causal_attention`, which has 1 for all input + # positions, including the BOS token, as the targets so when the +@@ -1095,6 +1212,7 @@ class DecoderOnlyModel(BaseTransformerModel): + return_all_decodes: bool = False, + num_decodes: int = 1, + decoder_params: Optional[MutableMapping[str, Any]] = None, ++ flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Predict with prefix. + +@@ -1193,7 +1311,7 @@ class DecoderOnlyModel(BaseTransformerModel): + inputs = batch['decoder_input_tokens'] * batch['decoder_causal_attention'] + + prefilled_cache, initial_index = self._compute_kv_cache( +- params, inputs, batch['decoder_causal_attention'] ++ params, inputs, batch['decoder_causal_attention'], flax_mutables=flax_mutables, + ) + + target_shape = batch['decoder_input_tokens'].shape +@@ -1203,6 +1321,7 @@ class DecoderOnlyModel(BaseTransformerModel): + self._compute_logits_from_slice, + params=params, + max_decode_length=max_decode_length, ++ flax_mutables=flax_mutables, + ) + + if decoder_params is None: +diff --git a/t5x/partitioning.py b/t5x/partitioning.py +index 910f666..2ae6e7a 100644 +--- a/t5x/partitioning.py ++++ b/t5x/partitioning.py +@@ -34,6 +34,7 @@ from jax.sharding import Mesh + from jax.sharding import PartitionSpec + import numpy as np + from t5x import train_state as train_state_lib ++from t5x.te_helper import TransformerEngineHelper + + JaxDevice = jax.Device + TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores). +@@ -522,6 +523,8 @@ def standard_logical_axis_rules( + if additional_rules: + rules.extend(additional_rules) + ++ rules = TransformerEngineHelper.extend_logical_axis_rules(rules) ++ + return rules + + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +new file mode 100644 +index 0000000..fb5f48f +--- /dev/null ++++ b/t5x/te_helper.py +@@ -0,0 +1,284 @@ ++# Copyright 2023 The T5X Authors. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++from absl import logging ++from contextlib import contextmanager ++import gin ++import jax ++ ++try: ++ from transformer_engine.common.recipe import DelayedScaling ++ from transformer_engine.common.recipe import Format as FP8Format ++ import transformer_engine.jax as te ++ _IS_TRANSFORMER_ENGINE_INSTALLED = True ++ ++except ModuleNotFoundError as e: ++ _IS_TRANSFORMER_ENGINE_INSTALLED = False ++ ++ ++def _canonicalize_fp8_format(fp8_format): ++ if not _IS_TRANSFORMER_ENGINE_INSTALLED: ++ return None ++ ++ fp8_format = fp8_format.lower() ++ if fp8_format in ['fp8_e4m3', 'fp8e4m3', 'e4m3']: ++ return FP8Format.E4M3 ++ if fp8_format in ['fp8_e5m2', 'fp8e5m2', 'e5m2']: ++ return FP8Format.E5M2 ++ if fp8_format in ['fp8_hybrid', 'fp8hybrid', 'hybrid']: ++ return FP8Format.HYBRID ++ raise ValueError('fp8_format must be one of [fp8_e4m3, fp8_e5m2, fp8_hybrid]' ++ f'but the value is {fp8_format}') ++ ++@gin.configurable ++class TransformerEngineConfig: ++ def __init__(self, enabled=False, fp8_format='fp8_hybrid', margin=0., amax_history_len=1024): ++ assert (_IS_TRANSFORMER_ENGINE_INSTALLED or (not enabled)), \ ++ 'Attempt to run transformer engine FP8 without installing transformer_engine.' ++ ++ self.enabled = enabled ++ self.fp8_format = _canonicalize_fp8_format(fp8_format) ++ self.margin = margin ++ self.amax_history_len = amax_history_len ++ ++ def __str__(self): ++ return f"TransformerEngineConfig: enabled:{self.enabled}," \ ++ f" fp8_format: {self.fp8_format}, margin: {self.margin}," \ ++ f" amax_history_len: {self.amax_history_len}." ++ ++ ++class TransformerEngineHelperBase: ++ ++ @staticmethod ++ def is_fp8_enabled(): ++ raise NotImplementedError ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(te_config, dp_mesh_axis=None, tp_mesh_axis=None): ++ raise NotImplementedError ++ ++ @staticmethod ++ def extend_logical_axis_rules(rules): ++ raise NotImplementedError ++ ++ @staticmethod ++ def update_fp8_metas(grad_accum, flax_mutables): ++ raise NotImplementedError ++ ++ @staticmethod ++ def check_dataset_cfg(config): ++ raise NotImplementedError ++ ++ @staticmethod ++ def get_t5x_config(config): ++ raise NotImplementedError ++ ++ @staticmethod ++ def get_attn_mask(mask): ++ raise NotImplementedError ++ ++ @staticmethod ++ def get_encoder_layer(config, relative_embedding, name, original_cls): ++ raise NotImplementedError ++ ++ @staticmethod ++ def get_decoder_layer(config, relative_embedding, name, original_cls): ++ raise NotImplementedError ++ ++ ++class TENotInstalledHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def is_fp8_enabled(): ++ return False ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(te_config, dp_mesh_axis=None, tp_mesh_axis=None): ++ try: ++ yield ++ finally: ++ pass ++ ++ @staticmethod ++ def extend_logical_axis_rules(rules): ++ return rules ++ ++ @staticmethod ++ def update_fp8_metas(grad_accum, flax_mutables): ++ return flax_mutables ++ ++ @staticmethod ++ def check_dataset_cfg(config): ++ pass ++ ++ @staticmethod ++ def get_t5x_config(config): ++ assert not config.transpose_batch_sequence, \ ++ "Only allow transpose_batch_sequence when Transformer Engine is installed." ++ return config ++ ++ @staticmethod ++ def get_attn_mask(mask): ++ return mask ++ ++ @staticmethod ++ def get_encoder_layer(config, relative_embedding, name, original_cls): ++ return original_cls(config=config, ++ relative_embedding=relative_embedding, name=name) ++ ++ @staticmethod ++ def get_decoder_layer(config, relative_embedding, name, original_cls): ++ return original_cls(config=config, ++ relative_embedding=relative_embedding, name=name) ++ ++ ++class TEInstalledHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def is_fp8_enabled(): ++ return te.fp8.FP8Helper.is_fp8_enabled() ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): ++ delay_scaling = DelayedScaling(margin=te_config.margin, ++ fp8_format=te_config.fp8_format, ++ amax_history_len=te_config.amax_history_len, ++ amax_compute_algo="max") ++ try: ++ with te.fp8_autocast(enabled=te_config.enabled, fp8_recipe=delay_scaling, ++ sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis)): ++ yield ++ finally: ++ pass ++ ++ @staticmethod ++ def extend_logical_axis_rules(rules): ++ # Apply fp8_autocast to correctly set sharding_resource up. ++ with TEInstalledHelper.fp8_autocast(TransformerEngineConfig()): ++ return te.extend_logical_axis_rules(rules) ++ ++ @staticmethod ++ def update_fp8_metas(grad_accum, flax_mutables): ++ update_coll = te.update_collections(grad_accum, flax_mutables) ++ # As the suggestion of FP8 training, updating FP8 scales as frequent as possible. ++ update_coll = te.update_fp8_metas(update_coll) ++ return update_coll ++ ++ @staticmethod ++ def check_dataset_cfg(config): ++ assert not config.pack, \ ++ "Transformer Engine does not support dataset.packing, please turn it off." ++ ++ @staticmethod ++ def get_t5x_config(config): ++ return config ++ ++ @staticmethod ++ def get_attn_mask(mask): ++ # Invert T5X's mask by 0->1, and 1->0 ++ mask_ = mask ++ mask_ = 1 - mask_.astype(jax.numpy.uint8) ++ return mask_ ++ ++ @staticmethod ++ def get_encoder_layer(config, relative_embedding, name, original_cls): ++ hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) ++ return te.TransformerLayer( ++ hidden_size=config.num_heads*config.head_dim, ++ mlp_hidden_size=config.mlp_dim, ++ layernorm_type="rmsnorm", ++ num_attention_heads=config.num_heads, ++ hidden_dropout=config.dropout_rate, ++ hidden_dropout_dims = hidden_dropout_dims, ++ attention_dropout=config.dropout_rate, ++ mlp_activations=config.mlp_activations, ++ transpose_batch_sequence=config.transpose_batch_sequence, ++ float32_attention_logits=config.float32_attention_logits, ++ scale_attn_logits=config.scale_attn_logits, ++ scaled_query_init=True, ++ fuse_qkv_params=config.fuse_qkv_params, ++ relative_embedding=relative_embedding, ++ dtype=config.dtype, layer_type=te.TransformerLayerType.ENCODER, name=name) ++ ++ @staticmethod ++ def get_decoder_layer(config, relative_embedding, name, original_cls): ++ hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) ++ return te.TransformerLayer( ++ hidden_size=config.num_heads*config.head_dim, ++ mlp_hidden_size=config.mlp_dim, ++ layernorm_type="rmsnorm", ++ num_attention_heads=config.num_heads, ++ hidden_dropout=config.dropout_rate, ++ hidden_dropout_dims = hidden_dropout_dims, ++ attention_dropout=config.dropout_rate, ++ mlp_activations=config.mlp_activations, ++ transpose_batch_sequence=config.transpose_batch_sequence, ++ float32_attention_logits=config.float32_attention_logits, ++ scale_attn_logits=config.scale_attn_logits, ++ scaled_query_init=True, ++ fuse_qkv_params=config.fuse_qkv_params, ++ relative_embedding=relative_embedding, ++ dtype=config.dtype, layer_type=te.TransformerLayerType.DECODER, name=name) ++ ++ ++class TransformerEngineHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def get_helper(): ++ if _IS_TRANSFORMER_ENGINE_INSTALLED: ++ return TEInstalledHelper ++ return TENotInstalledHelper ++ ++ @staticmethod ++ def is_fp8_enabled(): ++ return TransformerEngineHelper.get_helper().is_fp8_enabled() ++ ++ @staticmethod ++ @contextmanager ++ def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): ++ try: ++ with TransformerEngineHelper.get_helper().fp8_autocast(te_config, dp_mesh_axis, tp_mesh_axis): ++ yield ++ finally: ++ pass ++ ++ @staticmethod ++ def extend_logical_axis_rules(rules): ++ return TransformerEngineHelper.get_helper().extend_logical_axis_rules(rules) ++ ++ @staticmethod ++ def update_fp8_metas(grad_accum, flax_mutables): ++ return TransformerEngineHelper.get_helper().update_fp8_metas(grad_accum, flax_mutables) ++ ++ @staticmethod ++ def check_dataset_cfg(config): ++ return TransformerEngineHelper.get_helper().check_dataset_cfg(config) ++ ++ @staticmethod ++ def get_t5x_config(config): ++ return TransformerEngineHelper.get_helper().get_t5x_config(config) ++ ++ @staticmethod ++ def get_attn_mask(mask): ++ return TransformerEngineHelper.get_helper().get_attn_mask(mask) ++ ++ @staticmethod ++ def get_encoder_layer(config, relative_embedding, name, original_cls): ++ return TransformerEngineHelper.get_helper().get_encoder_layer(config, relative_embedding, name, original_cls) ++ ++ @staticmethod ++ def get_decoder_layer(config, relative_embedding, name, original_cls): ++ return TransformerEngineHelper.get_helper().get_decoder_layer(config, relative_embedding, name, original_cls) +diff --git a/t5x/train.py b/t5x/train.py +index d6bfd6a..e7ee77d 100644 +--- a/t5x/train.py ++++ b/t5x/train.py +@@ -45,6 +45,8 @@ from t5x import partitioning + from t5x import train_state as train_state_lib + from t5x import trainer as trainer_lib + from t5x import utils ++from t5x.te_helper import TransformerEngineConfig, TransformerEngineHelper ++import atexit + import tensorflow as tf + # pylint:enable=g-import-not-at-top + +@@ -107,6 +109,7 @@ def train( + total_steps: int, + eval_steps: int, + eval_period: int, ++ te_config_cls: Type[TransformerEngineConfig] = TransformerEngineConfig, + stats_period: Optional[int] = None, + random_seed: Optional[int], + use_hardware_rng: bool = False, +@@ -127,6 +130,7 @@ def train( + Callable[[utils.DatasetConfig, models.BaseTransformerModel], None] + ] = utils.verify_matching_vocabs, + gc_period: int = 0, ++ reset_state_after: Optional[int] = None, + ) -> Tuple[int, train_state_lib.TrainState]: + """Train function. + +@@ -188,6 +192,12 @@ def train( + instance. Should raise an exception on error. + gc_period: The number of train steps between runs of the garbage collector. + If 0, the garbage collector will run at the normal frequency. ++ reset_state_after: Optional number of steps after which to reset the ++ optimizer states and fp8 metadata. In finetuning, you may not want to ++ keep fp8 metas due to the distribution change messing up `amax` ++ statistics. Ex: to reset to finetuning, set this to the number of ++ pretraining steps. Only triggered if the train step on restore equals ++ reset_state_after. Otherwise ignored + + Returns: + The tuple of (last_step, last_train_state). +@@ -214,6 +224,18 @@ def train( + ' unsupported.' + ) + ++ te_config = te_config_cls() ++ logging.info(te_config) ++ ++ # Note(terry): The proper usage of the TE API is to use fp8_autocast as a ++ # contextmanager using the "with" statement. The reason it is done this ++ # way here is to avoid indenting the code. Please refer to TE documentation ++ # for more details. ++ te_ctx_mgr = TransformerEngineHelper.fp8_autocast(te_config) ++ # Register a hook with atexit in case exception raised ++ atexit.register(lambda: te_ctx_mgr.__exit__(None, None, None)) ++ te_ctx_mgr.__enter__() ++ + # Each "epoch" of the training loop should be the min of the eval period, + # checkpoint period or the full training. + # We compute here to ensure that the eval period and checkpoint period are +@@ -274,6 +296,8 @@ def train( + # Initialize datasets + # --------------------------------------------------------------------------- + ++ TransformerEngineHelper.check_dataset_cfg(train_dataset_cfg) ++ + if train_dataset_cfg.seed and not ( + checkpoint_cfg.save and checkpoint_cfg.save.save_dataset + ): +@@ -312,6 +336,7 @@ def train( + input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) + + if train_eval_dataset_cfg: ++ TransformerEngineHelper.check_dataset_cfg(train_eval_dataset_cfg) + _verify_matching_vocabs(train_eval_dataset_cfg) + train_eval_datasets = train_eval_get_dataset_fn( + train_eval_dataset_cfg, +@@ -410,6 +435,38 @@ def train( + ) + ) + ++ if train_state is not None: ++ # Only triggered if the train step is equal to reset_state_after right after ++ # restore ++ host_step = int(utils.get_local_data(train_state.step)) # pytype: disable=attribute-error ++ if reset_state_after and reset_state_after == host_step: ++ logging.info('Resetting optimizer and fp8 states. Preserving only optimizer targets') ++ assert use_orbax == False, "resetting the states in the train loop is not\ ++ supported with orbax. If you need to do this,\ ++ please delete the optimizer state and fp8 metas\ ++ from the checkpoint folder directly " ++ old_step = train_state.step ++ train_state_initializer = train_state_initializer_cls( ++ optimizer_def=model.optimizer_def, ++ init_fn=model.get_initial_variables, ++ input_shapes=input_shapes, ++ input_types=input_types, ++ partitioner=partitioner) ++ from_scratch_state = train_state_initializer.from_scratch(init_rng) ++ _optimizer = from_scratch_state._optimizer.replace( ++ target=train_state.params) ++ train_state = from_scratch_state.replace(_optimizer=_optimizer) ++ train_state = train_state.replace_step(old_step) ++ ++ checkpoint_manager = utils.LegacyCheckpointManager( ++ save_cfg=checkpoint_cfg.save, ++ restore_cfg=valid_restore_cfg, ++ train_state_shape=train_state_initializer.global_train_state_shape, ++ partitioner=partitioner, ++ ds_iter=train_iter, ++ model_dir=model_dir, ++ use_gda=use_gda) ++ + # Start warming up the input pipeline in the background. This must happen + # after input pipeline checkpoints were restored. + first_batch_ready = train_iter.peek_async() +@@ -459,6 +516,7 @@ def train( + # Init evaluator to set up cached datasets + evaluator = None + if infer_eval_dataset_cfg is not None: ++ TransformerEngineHelper.check_dataset_cfg(infer_eval_dataset_cfg) + evaluator = eval_lib.InferenceEvaluator( + infer_eval_dataset_cfg=infer_eval_dataset_cfg, + inference_evaluator_cls=inference_evaluator_cls, +@@ -800,6 +858,7 @@ def train( + # the same interpreter. + gc.enable() + ++ te_ctx_mgr.__exit__(None, None, None) + return host_step, trainer.train_state + + +diff --git a/t5x/train_state.py b/t5x/train_state.py +index 5ac52af..95cdd99 100644 +--- a/t5x/train_state.py ++++ b/t5x/train_state.py +@@ -222,7 +222,8 @@ class FlaxOptimTrainState(flax.struct.PyTreeNode): + _optimizer=self._optimizer.optimizer_def.derive_logical_axes( + self._optimizer, + flax_partitioning.get_axis_names(self.params_axes)), +- flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes)) ++ flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes), ++ flax_mutables_axes=self.flax_mutables_axes) + + + class InferenceState(flax.struct.PyTreeNode): +diff --git a/t5x/trainer.py b/t5x/trainer.py +index 752beb3..4eae811 100644 +--- a/t5x/trainer.py ++++ b/t5x/trainer.py +@@ -43,6 +43,7 @@ from t5x import models + from t5x import partitioning + from t5x import train_state as train_state_lib + from t5x import utils ++from t5x.te_helper import TransformerEngineHelper + import typing_extensions + + +@@ -675,7 +676,7 @@ def accumulate_grads_microbatched( + """ + batch_size = next(iter(batch.values())).shape[0] + +- grad_fn = jax.value_and_grad(model.loss_fn, has_aux=True) ++ grad_fn = jax.value_and_grad(model.loss_fn, argnums=(0, 3), has_aux=True) + + # We assume that the model loss_fn supports flax mutables if and only if + # the train state includes non-empty flax mutables. +@@ -683,19 +684,13 @@ def accumulate_grads_microbatched( + # them and return flax_mutables from `get_initial_variables` and `loss_fn`. + + initial_flax_mutables = ( +- train_state.flax_mutables if train_state.flax_mutables else None ++ train_state.flax_mutables if train_state.flax_mutables else {} + ) + + if num_microbatches is None or num_microbatches <= 1: +- +- if initial_flax_mutables is None: +- (_, metrics), grad_accum = grad_fn(train_state.params, batch, dropout_rng) +- flax_mutables = None +- else: +- (_, (metrics, +- flax_mutables)), grad_accum = grad_fn(train_state.params, batch, +- dropout_rng, +- initial_flax_mutables) ++ (_, metrics), grad_accum = grad_fn(train_state.params, batch, ++ dropout_rng, initial_flax_mutables) ++ flax_mutables=initial_flax_mutables + else: + assert batch_size % num_microbatches == 0, ( + "Batch size isn't divided evenly by num_microbatches.") +@@ -722,13 +717,9 @@ def accumulate_grads_microbatched( + lambda x: partitioning.with_sharding_constraint( # pylint: disable=g-long-lambda + x, data_partition_spec), + mbatch) +- if flax_mutables is None: +- (_, metrics), grad = grad_fn(train_state.params, mbatch, +- sub_dropout_rng) +- else: +- (_, (metrics, flax_mutables)), grad = grad_fn(train_state.params, +- mbatch, sub_dropout_rng, +- flax_mutables) ++ (_, metrics), grad = grad_fn(train_state.params, mbatch, ++ sub_dropout_rng, flax_mutables) ++ + return metrics, grad, flax_mutables + + def per_microbatch_train_step( +@@ -741,7 +732,9 @@ def accumulate_grads_microbatched( + metrics, grad, flax_mutables = metrics_and_grad(loop_cnt, dropout_rng, + flax_mutables) + +- grad_accum = jax.tree_util.tree_map(jnp.add, grad_accum, grad) ++ grad_accum[0] = jax.tree_map(jnp.add, grad_accum[0], grad[0]) ++ grad_accum[1] = TransformerEngineHelper.update_fp8_metas(grad[1], flax_mutables) ++ flax_mutables = grad_accum[1] + metrics = jax.lax.cond(loop_cnt == 0, lambda _: metrics, + lambda _: merge_metrics(prev_metrics, metrics), + None) +@@ -749,8 +742,10 @@ def accumulate_grads_microbatched( + + # Initialize gradient accumulation loop state. + accum_dtype = jnp.float32 +- grad_accum_init = jax.tree_util.tree_map( +- lambda x: jnp.zeros(x.shape, accum_dtype), train_state.params) ++ grad_accum_init = [jax.tree_map(lambda x: jnp.zeros(x.shape, accum_dtype), ++ train_state.params), ++ jax.tree_map(lambda x: jnp.zeros(x.shape, accum_dtype), ++ train_state.flax_mutables)] + initial_metrics_shape, _, _ = jax.eval_shape( + metrics_and_grad, + loop_cnt=0, +@@ -769,6 +764,9 @@ def accumulate_grads_microbatched( + + del new_dropout_rng + ++ grad_accum = (grad_accum[0], ++ TransformerEngineHelper.update_fp8_metas(grad_accum[1], flax_mutables)) ++ + return grad_accum, metrics, flax_mutables + + +@@ -797,9 +795,12 @@ def apply_grads( + """ + if other_state_variables is None: + other_state_variables = {} ++ ++ other_state_variables["flax_mutables"] = FrozenDict(grad_accum[1]) ++ + # Update optimizer using accumulated gradient. + new_train_state = train_state.apply_gradient( +- grad_accum, learning_rate=learning_rate, **other_state_variables) ++ grad_accum[0], learning_rate=learning_rate, **other_state_variables) + metrics["learning_rate"] = clu.metrics.Average.from_model_output( + jnp.asarray([learning_rate])) + metrics["learning_rate/current"] = clu.metrics.LastValue.from_model_output( +-- +2.25.1 + + +From 325f0e38720581c4f2535f3da23f40f425560138 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Tue, 11 Jul 2023 12:10:33 -0700 +Subject: [PATCH 02/16] UNINSTALL_TE in fine-tuning scripts now defaults to + no-action + +--- + .../gpu/scripts_gpu/multiprocess_ft_frompile.sh | 10 +++------- + 1 file changed, 3 insertions(+), 7 deletions(-) + +diff --git a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh +index 388d2ec..135ecf6 100755 +--- a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh +@@ -72,13 +72,9 @@ case $FT_TASK in + ;; + esac + +-case $UNINSTALL_TE in +- 0) +- ;; +- *) +- pip uninstall -y transformer_engine +- ;; +-esac ++if [[ -n "${UNINSTALL_TE:-}" && ${UNINSTALL_TE:-} -ne 0 ]]; then ++ pip uninstall -y transformer_engine ++fi + + # Global batch size + BSIZE=$(( GPUS_PER_NODE * BSIZE_PER_GPU * SLURM_JOB_NUM_NODES / TP_SIZE)) +-- +2.25.1 + + +From f75aa370fe5c82dc8cf2f5a3ef3bdba407b60628 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Wed, 12 Jul 2023 20:26:28 -0700 +Subject: [PATCH 03/16] remove use_gda from LegacyCheckpointManager in train.py + for fp8 + +--- + t5x/train.py | 3 +-- + 1 file changed, 1 insertion(+), 2 deletions(-) + +diff --git a/t5x/train.py b/t5x/train.py +index e7ee77d..7fe705c 100644 +--- a/t5x/train.py ++++ b/t5x/train.py +@@ -464,8 +464,7 @@ def train( + train_state_shape=train_state_initializer.global_train_state_shape, + partitioner=partitioner, + ds_iter=train_iter, +- model_dir=model_dir, +- use_gda=use_gda) ++ model_dir=model_dir) + + # Start warming up the input pipeline in the background. This must happen + # after input pipeline checkpoints were restored. +-- +2.25.1 + + +From 395cf5bcca9c26f7fdaabca8541aa8368bcb8f91 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Tue, 18 Jul 2023 15:55:01 -0700 +Subject: [PATCH 04/16] Allow singlenode scripts to tee to stdout for better + indication of training status + +--- + t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh | 2 +- + t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh | 2 +- + 2 files changed, 2 insertions(+), 2 deletions(-) + +diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh +index c27c4a0..73e1139 100755 +--- a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh +@@ -90,5 +90,5 @@ python3 -u ${T5X_DIR}/t5x/train.py \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ +- &> \ ++ 2>&1 | tee \ + ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log +diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +index def1a1a..0d12f30 100755 +--- a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +@@ -62,5 +62,5 @@ python3 -u ${T5X_DIR}/t5x/train.py \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ +- &> \ ++ 2>&1 | tee \ + ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log +-- +2.25.1 + + +From 7cc636157b1afd8496baea7977876124b205db9e Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Fri, 14 Jul 2023 05:00:58 -0700 +Subject: [PATCH 05/16] Explicit specify self_attn_mask_type + +--- + t5x/te_helper.py | 10 ++++++++-- + 1 file changed, 8 insertions(+), 2 deletions(-) + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index fb5f48f..f3750ca 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -211,7 +211,10 @@ class TEInstalledHelper(TransformerEngineHelperBase): + scaled_query_init=True, + fuse_qkv_params=config.fuse_qkv_params, + relative_embedding=relative_embedding, +- dtype=config.dtype, layer_type=te.TransformerLayerType.ENCODER, name=name) ++ dtype=config.dtype, ++ layer_type=te.TransformerLayerType.ENCODER, ++ self_attn_mask_type='padding', ++ name=name) + + @staticmethod + def get_decoder_layer(config, relative_embedding, name, original_cls): +@@ -231,7 +234,10 @@ class TEInstalledHelper(TransformerEngineHelperBase): + scaled_query_init=True, + fuse_qkv_params=config.fuse_qkv_params, + relative_embedding=relative_embedding, +- dtype=config.dtype, layer_type=te.TransformerLayerType.DECODER, name=name) ++ dtype=config.dtype, ++ layer_type=te.TransformerLayerType.DECODER, ++ self_attn_mask_type='causal', ++ name=name) + + + class TransformerEngineHelper(TransformerEngineHelperBase): +-- +2.25.1 + + +From d17222e423d83ea2faba226fad5f6a2b1bc9307f Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Thu, 3 Aug 2023 14:25:22 -0700 +Subject: [PATCH 06/16] Disables check for packing by the te_helper util since + not all dataset configs use packing (CV/Multimodal) + +--- + t5x/te_helper.py | 2 ++ + 1 file changed, 2 insertions(+) + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index f3750ca..f585752 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -179,6 +179,8 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def check_dataset_cfg(config): ++ if not hasattr(config, 'pack'): ++ return + assert not config.pack, \ + "Transformer Engine does not support dataset.packing, please turn it off." + +-- +2.25.1 + + +From c0448f9284bf910699dc7ecba202c2f213fa7481 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Sat, 26 Aug 2023 16:13:11 -0700 +Subject: [PATCH 07/16] Corrected T5x large baselines + +Updated T5x-large MNLI and SQUAD baselines +--- + docs/usage/gpu-usage.md | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/docs/usage/gpu-usage.md b/docs/usage/gpu-usage.md +index a9974e1..660df3a 100644 +--- a/docs/usage/gpu-usage.md ++++ b/docs/usage/gpu-usage.md +@@ -35,7 +35,7 @@ For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2 + | size | GPU | Precision | #GPUs | TP | BS / GPU | Sequences/Sec | Seq/Sec/GPU | Est. Walltime | GPU-days | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | Config | + | ---- | ------------ | --------- | ----- | ----- | -------- | ------------- | ----------- | ------------- | -------- |------------------ | ------------------ | --------------- | ---- | + | [T5-v1.1-small](../t5/t5_1_1/small.gin) | A100 80G SXM | bf16 | 8 | 1 | 256 | ~5712 | 714 | 4.2 days | 33 | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | [pile](../t5/t5_1_1/examples/small_pile_pretrain.gin) +-| [T5-v1.1-large](../t5/t5_1_1/large.gin) | A100 80G SXM | bf16 | 64 | 1 | 32 | ~4853 | 75.8 | 4.8 days | 309 | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) ++| [T5-v1.1-large](../t5/t5_1_1/large.gin) | A100 80G SXM | bf16 | 64 | 1 | 32 | ~4853 | 75.8 | 4.8 days | 309 | 89.23% | 86.12 / 93.21 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) + | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 144 | 1 | 8 | ~3021 | 21.0 | 7.9 days | 1,133 | N/A(perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + | [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) +-- +2.25.1 + + +From 443df5e4414dcbac5482fc2672e1484a554c1bd2 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Fri, 8 Sep 2023 15:09:08 -0700 +Subject: [PATCH 08/16] Add t5-large FP8 logs + +--- + docs/usage/gpu-usage.md | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/docs/usage/gpu-usage.md b/docs/usage/gpu-usage.md +index 660df3a..c31094d 100644 +--- a/docs/usage/gpu-usage.md ++++ b/docs/usage/gpu-usage.md +@@ -39,7 +39,7 @@ For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2 + | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 144 | 1 | 8 | ~3021 | 21.0 | 7.9 days | 1,133 | N/A(perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + | [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) +-| [T5-v1.1-large](../t5/t5_1_1/large.gin) | **H100 80G SXM** | TE-fp8 | 64 | 1 | 32 | ~10156 | **158.7** | **2.3 days** | **147** | 89.1% | 86.36 / 93.5 | |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) ++| [T5-v1.1-large](../t5/t5_1_1/large.gin) | **H100 80G SXM** | TE-fp8 | 64 | 1 | 32 | ~10156 | **158.7** | **2.3 days** | **147** | 89.1% | 86.36 / 93.5 | [log](https://tensorboard.dev/experiment/QJYnDaaBSeuZtYPXXtAG3Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) + | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 144 | 1 | 14 | ~7257 | **50.4** | **3.3 days** | **475** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 256 | 1 | 8 | ~9688 | **37.8** | **2.4 days** | **614** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) + +-- +2.25.1 + + +From 8494e0ba9683f1776500d0a401f29633f6b6130b Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Fri, 20 Oct 2023 14:26:09 +0800 +Subject: [PATCH 09/16] Fix missing fp8_meta_collection in the eval stage. + +--- + t5x/models.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/t5x/models.py b/t5x/models.py +index fcd9fd5..a4276ec 100644 +--- a/t5x/models.py ++++ b/t5x/models.py +@@ -758,7 +758,7 @@ class EncoderDecoderModel(BaseTransformerModel): + decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens) + + encoded_inputs = self.module.apply( +- {'params': params}, ++ {'params': params, **flax_mutables}, + encoder_input_tokens, + enable_dropout=False, + method=self.module.encode, +-- +2.25.1 + + +From 181087dbc71a268e4d96fbd3adb429ab630f0486 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Fri, 20 Oct 2023 14:48:40 +0800 +Subject: [PATCH 10/16] Remove redundant code. + +--- + t5x/models.py | 17 ----------------- + 1 file changed, 17 deletions(-) + +diff --git a/t5x/models.py b/t5x/models.py +index a4276ec..faeb6d0 100644 +--- a/t5x/models.py ++++ b/t5x/models.py +@@ -728,23 +728,6 @@ class EncoderDecoderModel(BaseTransformerModel): + encoder_input_tokens = batch['encoder_input_tokens'] + decoder_input_tokens = batch['decoder_input_tokens'] + +- # Prepare transformer fast-decoder call for beam search: for beam search, we +- # need to set up our decoder model to handle a batch size equal to +- # batch_size * num_decodes, where each batch item's data is expanded +- # in-place rather than tiled. +- # i.e. if we denote each batch element subtensor as el[n]: +- # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] +- # [batch * num_decodes, input_len, emb_dim] +- encoded_inputs = decoding.flat_batch_beam_expand( +- self.module.apply( +- {'params': params, **flax_mutables}, +- encoder_input_tokens, +- enable_dropout=False, +- method=self.module.encode, +- ), +- num_decodes, +- ) +- + # `decoder_prompt_inputs` is initialized from the batch's + # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop + # after the prompt by matching to `output_vocabulary.eos_id`. +-- +2.25.1 + + +From 4d051703078985ca95b23def87672513dbce9daa Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Fri, 20 Oct 2023 15:05:40 +0800 +Subject: [PATCH 11/16] Fix deprecating warning about TE. + +--- + t5x/te_helper.py | 8 ++++---- + 1 file changed, 4 insertions(+), 4 deletions(-) + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index f585752..568f596 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -198,7 +198,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + @staticmethod + def get_encoder_layer(config, relative_embedding, name, original_cls): + hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) +- return te.TransformerLayer( ++ return te.flax.TransformerLayer( + hidden_size=config.num_heads*config.head_dim, + mlp_hidden_size=config.mlp_dim, + layernorm_type="rmsnorm", +@@ -214,14 +214,14 @@ class TEInstalledHelper(TransformerEngineHelperBase): + fuse_qkv_params=config.fuse_qkv_params, + relative_embedding=relative_embedding, + dtype=config.dtype, +- layer_type=te.TransformerLayerType.ENCODER, ++ layer_type=te.flax.TransformerLayerType.ENCODER, + self_attn_mask_type='padding', + name=name) + + @staticmethod + def get_decoder_layer(config, relative_embedding, name, original_cls): + hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) +- return te.TransformerLayer( ++ return te.flax.TransformerLayer( + hidden_size=config.num_heads*config.head_dim, + mlp_hidden_size=config.mlp_dim, + layernorm_type="rmsnorm", +@@ -237,7 +237,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + fuse_qkv_params=config.fuse_qkv_params, + relative_embedding=relative_embedding, + dtype=config.dtype, +- layer_type=te.TransformerLayerType.DECODER, ++ layer_type=te.flax.TransformerLayerType.DECODER, + self_attn_mask_type='causal', + name=name) + +-- +2.25.1 + + +From 5fff9dc6557e084378d24fc9c14376331e7f3d3a Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Fri, 27 Oct 2023 09:08:10 -0700 +Subject: [PATCH 12/16] Updates TE api from te.extend_* to te.flax.extend_* + (#7) + +Co-authored-by: NVIDIA +--- + t5x/te_helper.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index 568f596..05c5f6b 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -168,7 +168,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + def extend_logical_axis_rules(rules): + # Apply fp8_autocast to correctly set sharding_resource up. + with TEInstalledHelper.fp8_autocast(TransformerEngineConfig()): +- return te.extend_logical_axis_rules(rules) ++ return te.flax.extend_logical_axis_rules(rules) + + @staticmethod + def update_fp8_metas(grad_accum, flax_mutables): +-- +2.25.1 + + +From e36be07a36c476545a5bbbb59e5c7a44419deb02 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Tue, 31 Oct 2023 21:52:22 -0700 +Subject: [PATCH 13/16] Adds ENABLE_TE env var and renames TEConfig.enabled -> + TEConfig.enable_fp8 (#8) + +* Allows ENABLE_TE env var to control whether TE code path is invoked + +* Changes enabled -> enable_fp8 to be more consistent with PAX and avoid confusion with ENABLE_TE + +* Remove UNINSTALL_TE logic since it is no longer required + +--------- + +Co-authored-by: NVIDIA +--- + .../scripts_gpu/multiprocess_ft_frompile.sh | 6 +---- + .../scripts_gpu/multiprocess_pretrain_pile.sh | 2 +- + .../gpu/scripts_gpu/singlenode_ft_frompile.sh | 2 +- + .../scripts_gpu/singlenode_pretrain_pile.sh | 2 +- + t5x/te_helper.py | 22 +++++++++++++------ + 5 files changed, 19 insertions(+), 15 deletions(-) + +diff --git a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh +index 135ecf6..cd563ec 100755 +--- a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh +@@ -72,10 +72,6 @@ case $FT_TASK in + ;; + esac + +-if [[ -n "${UNINSTALL_TE:-}" && ${UNINSTALL_TE:-} -ne 0 ]]; then +- pip uninstall -y transformer_engine +-fi +- + # Global batch size + BSIZE=$(( GPUS_PER_NODE * BSIZE_PER_GPU * SLURM_JOB_NUM_NODES / TP_SIZE)) + export GPU_DEVICES=$(seq -s, 0 $((GPUS_PER_NODE - 1)) ) +@@ -105,7 +101,7 @@ python3 -u ${T5X_DIR}/t5x/train.py \ + --gin.train.eval_period=1000 \ + --gin.train.gc_period=2000 \ + --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ +- --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ ++ --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ +diff --git a/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh +index f807105..d083540 100755 +--- a/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh +@@ -95,7 +95,7 @@ python3 ${T5X_DIR}/t5x/train.py \ + --gin.train.eval_period=1000 \ + --gin.train.gc_period=${TRAIN_STEPS} \ + --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ +- --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ ++ --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ +diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh +index 73e1139..981fb21 100755 +--- a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh +@@ -86,7 +86,7 @@ python3 -u ${T5X_DIR}/t5x/train.py \ + --gin.train/utils.DatasetConfig.pack=${PACK} \ + --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ + --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ +- --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ ++ --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ +diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +index 0d12f30..8ca54b0 100755 +--- a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh ++++ b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +@@ -58,7 +58,7 @@ python3 -u ${T5X_DIR}/t5x/train.py \ + --gin.train/utils.DatasetConfig.pack=${PACK} \ + --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ + --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ +- --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ ++ --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index 05c5f6b..7657c52 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -15,16 +15,20 @@ from absl import logging + from contextlib import contextmanager + import gin + import jax ++import os ++ ++logging.set_verbosity(logging.INFO) + + try: + from transformer_engine.common.recipe import DelayedScaling + from transformer_engine.common.recipe import Format as FP8Format + import transformer_engine.jax as te + _IS_TRANSFORMER_ENGINE_INSTALLED = True ++ logging.info('Transformer Engine is installed') + + except ModuleNotFoundError as e: + _IS_TRANSFORMER_ENGINE_INSTALLED = False +- ++ logging.info('Transformer Engine is not installed') + + def _canonicalize_fp8_format(fp8_format): + if not _IS_TRANSFORMER_ENGINE_INSTALLED: +@@ -42,17 +46,17 @@ def _canonicalize_fp8_format(fp8_format): + + @gin.configurable + class TransformerEngineConfig: +- def __init__(self, enabled=False, fp8_format='fp8_hybrid', margin=0., amax_history_len=1024): +- assert (_IS_TRANSFORMER_ENGINE_INSTALLED or (not enabled)), \ ++ def __init__(self, enable_fp8=False, fp8_format='fp8_hybrid', margin=0., amax_history_len=1024): ++ assert (_IS_TRANSFORMER_ENGINE_INSTALLED or (not enable_fp8)), \ + 'Attempt to run transformer engine FP8 without installing transformer_engine.' + +- self.enabled = enabled ++ self.enable_fp8 = enable_fp8 + self.fp8_format = _canonicalize_fp8_format(fp8_format) + self.margin = margin + self.amax_history_len = amax_history_len + + def __str__(self): +- return f"TransformerEngineConfig: enabled:{self.enabled}," \ ++ return f"TransformerEngineConfig: enable_fp8:{self.enable_fp8}," \ + f" fp8_format: {self.fp8_format}, margin: {self.margin}," \ + f" amax_history_len: {self.amax_history_len}." + +@@ -158,7 +162,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + amax_history_len=te_config.amax_history_len, + amax_compute_algo="max") + try: +- with te.fp8_autocast(enabled=te_config.enabled, fp8_recipe=delay_scaling, ++ with te.fp8_autocast(enabled=te_config.enable_fp8, fp8_recipe=delay_scaling, + sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis)): + yield + finally: +@@ -243,10 +247,14 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + + class TransformerEngineHelper(TransformerEngineHelperBase): ++ @staticmethod ++ def is_enabled_te(): ++ enable_te = bool(int((os.environ.get("ENABLE_TE", False)))) ++ return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te) + + @staticmethod + def get_helper(): +- if _IS_TRANSFORMER_ENGINE_INSTALLED: ++ if TransformerEngineHelper.is_enabled_te(): + return TEInstalledHelper + return TENotInstalledHelper + +-- +2.25.1 + + +From 4b8a1ad34509a5f62dea7ce0d3efef87e2a4e9d1 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Tue, 7 Nov 2023 10:57:34 +0800 +Subject: [PATCH 14/16] Adapting to TE/JAX/Custom_partitioning. + +--- + t5x/te_helper.py | 5 ++--- + 1 file changed, 2 insertions(+), 3 deletions(-) + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index 7657c52..b064d2b 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -163,7 +163,8 @@ class TEInstalledHelper(TransformerEngineHelperBase): + amax_compute_algo="max") + try: + with te.fp8_autocast(enabled=te_config.enable_fp8, fp8_recipe=delay_scaling, +- sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis)): ++ mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis, ++ tp_resource=tp_mesh_axis)): + yield + finally: + pass +@@ -177,8 +178,6 @@ class TEInstalledHelper(TransformerEngineHelperBase): + @staticmethod + def update_fp8_metas(grad_accum, flax_mutables): + update_coll = te.update_collections(grad_accum, flax_mutables) +- # As the suggestion of FP8 training, updating FP8 scales as frequent as possible. +- update_coll = te.update_fp8_metas(update_coll) + return update_coll + + @staticmethod +-- +2.25.1 + + +From d08a6847857b6549ed8bfe1441d3742a6a13ac08 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 22 Nov 2023 13:53:49 +0800 +Subject: [PATCH 15/16] Running Partitioner.compile within Mesh context-manager + +--- + t5x/partitioning.py | 9 ++++++++- + 1 file changed, 8 insertions(+), 1 deletion(-) + +diff --git a/t5x/partitioning.py b/t5x/partitioning.py +index 2ae6e7a..d9ba8bd 100644 +--- a/t5x/partitioning.py ++++ b/t5x/partitioning.py +@@ -783,6 +783,13 @@ class PjittedFnWithContext(PartitionedCallable): + self._logical_axis_rules): + return self._pjitted_fn.lower(*args, **kwargs) + ++ def lower_and_compile(self, *args, **kwargs): ++ with Mesh(self._mesh.devices, ++ self._mesh.axis_names), flax_partitioning.axis_rules( ++ self._logical_axis_rules): ++ return self._pjitted_fn.lower(*args, **kwargs).compile() ++ ++ + + class BasePjitPartitioner(BasePartitioner): + """Partitioner that uses T5X version of jax.pjit.""" +@@ -816,7 +823,7 @@ class BasePjitPartitioner(BasePartitioner): + + def compile(self, partitioned_fn: PjittedFnWithContext, + *args) -> CompiledPartitionedCallable: +- return partitioned_fn.lower(*args).compile() ++ return partitioned_fn.lower_and_compile(*args) + + + class PjitPartitioner(BasePjitPartitioner): +-- +2.25.1 + + +From a08b8bf78e14c288d6d7a66ad17a9bea69044cd7 Mon Sep 17 00:00:00 2001 +From: Terry Kong +Date: Tue, 14 Nov 2023 23:42:35 -0800 +Subject: [PATCH 16/16] Updates multiprocessing scripts to use SLURM output + variables instead of input variables (#9) + +* Update multiprocess scripts + +* No longer need UNINSTALL_TE + +* Removes slurm scripts as the source of truth has moved to rosetta + +* Adds "Finished" message to multiprocess scripts + +* Remove BENCHMARK_ARGS which is no longer used + +* Fix typo in BENCHMARK_MODE and straggling if keyword + +* Address comments +--- + .../scripts_gpu/example_slurm_ft_frompile.sub | 98 ------------------- + .../example_slurm_pretrain_pile.sub | 92 ----------------- + .../scripts_gpu/multiprocess_ft_frompile.sh | 67 +++++++++---- + .../scripts_gpu/multiprocess_pretrain_pile.sh | 70 ++++++++----- + 4 files changed, 91 insertions(+), 236 deletions(-) + delete mode 100755 t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub + delete mode 100755 t5x/contrib/gpu/scripts_gpu/example_slurm_pretrain_pile.sub + +diff --git a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub b/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub +deleted file mode 100755 +index 19ec14e..0000000 +--- a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub ++++ /dev/null +@@ -1,98 +0,0 @@ +-#!/bin/bash +-#SBATCH -A example # slurm account +-#SBATCH -p partition # slurm partition name +-#SBATCH -N 1 # number of nodes +-#SBATCH -t 04:00:00 # wall time +-#SBATCH -J "t5x:train" # slurm job name +-#SBATCH --exclusive # exclusive node access +-#SBATCH --mem=0 # all mem avail +-#SBATCH --mail-type=FAIL # only send email on failure +-#SBATCH --overcommit +-#SBATCH --dependency=singleton # tells slurm to run only one job with the same job name at a time +-set -x +- +-# Copyright 2022 The T5X Authors. +-# +-# Licensed under the Apache License, Version 2.0 (the "License"); +-# you may not use this file except in compliance with the License. +-# You may obtain a copy of the License at +-# +-# http://www.apache.org/licenses/LICENSE-2.0 +-# +-# Unless required by applicable law or agreed to in writing, software +-# distributed under the License is distributed on an "AS IS" BASIS, +-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-# See the License for the specific language governing permissions and +-# limitations under the License. +- +-# File system and volume glue code +-#------------------------------------------------------------------------------- +-# << CHANGE ! >> +-SLURM_ACCOUNT='example' +-USERID='exampleperson' +- +-# << CHANGE ! >> +-CONTAINER="" # Add link to your built container +- +-# << CHANGE ! >> +-BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo +-BASE_TFDS_DATA_DIR="" # path to tfds data directory +-BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped +- +-# Default env variables for paths required by t5x training scripts +-TFDS_DATA_DIR=/t5x_home/datasets/ +-T5X_DIR=/t5x_home/ +-T5X_WORKSPACE_DIR=/t5x_home/workspace +- +-# Add the T5x/JAX specific mounts +-MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" +- +-# Add T5x/JAX specific exports +-EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR},PYTHONPATH=${T5X_DIR}" +-#------------------------------------------------------------------------------- +- +-FT_TASK=${FT_TASK:-mnli2} +-PREC=${PREC:="bfloat16"} +-T5_SIZE=${T5_SIZE:="large"} +-BSIZE_PER_GPU=${BSIZE_PER_GPU:=32} +-ENC_SL=${ENC_SL:=512} +-DEC_SL=${DEC_SL:=128} +-NUM_MICROBATCHES=${NUM_MICROBATCHES:=1} +-ENABLE_FP8=${ENABLE_FP8:=1} +-TP_SIZE=${TP_SIZE:=1} +-TRANSPOSE_BS=${TRANSPOSE_BS:=1} +-MODEL_DIR=${MODEL_DIR:=model_dir} +-FUSE_QKV=${FUSE_QKV:=1} +-PACK=${PACK:=0} +- +-export GPUS_PER_NODE=${1:-8} +-export BASE_SCRIPT=${2:-"${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh"} +-export WITH_MP=1 +- +-NUM_GPUS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) +- +-# redirect both stdout and stderr in the same file for ease of analysis +-OUTDIR="outputs/multinode/${TASK}_t5_${T5_SIZE}-prec_${PREC}-nodes_${SLURM_JOB_NUM_NODES}-gpus_${NUM_GPUS}-bs_${BSIZE_PER_GPU}-sl_${SL}" +- +-OUTFILE="${BASE_T5X_WORKSPACE_DIR}/${OUTDIR}/output-%j-%n.txt" +- +-LOGDIR="${T5X_WORKSPACE_DIR}/${OUTDIR}" +- +-# << CHANGE ! >> +-# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. +-# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ +-read -r -d '' cmd <> +-SLURM_ACCOUNT='example' +-USERID='exampleperson' +- +-# << CHANGE ! >> +-CONTAINER="" # Add link to your built container +- +-# << CHANGE ! >> +-BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo +-BASE_TFDS_DATA_DIR="" # path to tfds data directory +-BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped +- +-# Default env variables for paths required by t5x training scripts +-TFDS_DATA_DIR=/t5x_home/datasets/ +-T5X_DIR=/t5x_home/ +-T5X_WORKSPACE_DIR=/t5x_home/workspace +- +-# Add the T5x/JAX specific mounts +-MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" +- +-# Add T5x/JAX specific exports +-EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR},PYTHONPATH=${T5X_DIR}" +-#------------------------------------------------------------------------------- +- +-# Command line arguments needed by the underlying scripts +-PREC=${PREC:="bfloat16"} +-T5_SIZE=${T5_SIZE:="large"} +-BSIZE_PER_GPU=${BSIZE_PER_GPU:=32} +-ENC_SL=${ENC_SL:=512} +-DEC_SL=${DEC_SL:=128} +-TRAIN_STEPS=${TRAIN_STEPS:=500} +-NUM_MICROBATCHES=${NUM_MICROBATCHES:=1} +-ENABLE_FP8=${ENABLE_FP8:=1} # Uses TransformerEngine FP8 +-TP_SIZE=${TP_SIZE:=1} +-TRANSPOSE_BS=${TRANSPOSE_BS:=1} # An optimization for GPUs +-MODEL_DIR=${MODEL_DIR} +-FUSE_QKV=${FUSE_QKV:=1} # Used with TransformerEngine +-PACK=${PACK:=0} # Not supported with TransformerEngine +- +-export GPUS_PER_NODE=${1:-8} +-export BASE_SCRIPT=${2:-"${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh"} +-export WITH_MP=1 +- +-NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) +- +-# << CHANGE ! >> +-# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. +-# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ +-read -r -d '' cmd < bool: ++ # TODO: Currently only checks if the library is the same. ++ # It is more robust to check the other properties of requirement ++ # but this is sufficient for us. ++ if os.environ.get('JAX_TOOLBOX_VCS_EQUIVALENCY', False) and \ ++ self.candidate.name == candidate.name and \ ++ self.candidate.source_link.is_vcs and \ ++ candidate.source_link.is_vcs \ ++ : ++ return True + return candidate == self.candidate + + +diff --git a/src/pip/_vendor/resolvelib/resolvers.py b/src/pip/_vendor/resolvelib/resolvers.py +index 2c3d0e306..1a87b3289 100644 +--- a/src/pip/_vendor/resolvelib/resolvers.py ++++ b/src/pip/_vendor/resolvelib/resolvers.py +@@ -5,6 +5,8 @@ import operator + from .providers import AbstractResolver + from .structs import DirectedGraph, IteratorMapping, build_iter_view + ++import os ++ + RequirementInformation = collections.namedtuple( + "RequirementInformation", ["requirement", "parent"] + ) +@@ -165,8 +167,25 @@ class Resolution(object): + else: + information = [RequirementInformation(requirement, parent)] + ++ # Change generator -> concrete list ++ matches = build_iter_view(matches) ++ # Force the order of matches (first appears to be used as representative ++ # for the equivalency class). This puts vcs installs with @ at the top ++ def vcs_compare(req): ++ if hasattr(req, 'source_link') \ ++ and req.source_link \ ++ and req.source_link.is_vcs \ ++ and ('@' in req.source_link.show_url): ++ # 0 is preferred and is bubbled to the top ++ return 0 ++ return 1 ++ ++ # Need to check criterion as well b/c if None then this is first instance ++ # of this identifier and this will error ++ if os.environ.get('JAX_TOOLBOX_VCS_EQUIVALENCY', False) and criterion: ++ matches = build_iter_view(sorted(list(matches), key=vcs_compare)) + criterion = Criterion( +- candidates=build_iter_view(matches), ++ candidates=matches, + information=information, + incompatibilities=incompatibilities, + ) diff --git a/.github/workflows/_build_jax.yaml b/.github/workflows/_build_jax.yaml index 3732c374a..20ebdba8d 100644 --- a/.github/workflows/_build_jax.yaml +++ b/.github/workflows/_build_jax.yaml @@ -17,36 +17,6 @@ on: description: "Build date in YYYY-MM-DD format" required: false default: 'NOT SPECIFIED' - REPO_JAX: - type: string - description: URL of JAX repository to check out - required: false - default: "https://github.com/google/jax.git" - REF_JAX: - type: string - description: Git commit, tag, or branch for JAX - required: false - default: main - REPO_XLA: - type: string - description: URL of OpenXLA repository to check out - required: false - default: "https://github.com/openxla/xla.git" - REF_XLA: - type: string - description: Git commit, tag, or branch for XLA - required: false - default: main - REPO_TE: - type: string - description: URL of transformer engine repository to check out - required: false - default: "https://github.com/NVIDIA/TransformerEngine.git" - REF_TE: - type: string - description: Git commit, tag, or branch for XLA - required: false - default: main ARTIFACT_NAME: type: string description: 'Name of the artifact zip file' @@ -67,6 +37,11 @@ on: description: 'User email in GIT to perform git pull/push' required: false default: 'jax@nvidia.com' + TRIAL_BRANCH: + type: string + description: 'Name of branch with bumped manifest and patches' + required: true + default: 'main' outputs: DOCKER_TAG_MEALKIT: description: "Tags of the 'mealkit' image built" @@ -113,6 +88,8 @@ jobs: - name: Check out the repository under ${GITHUB_WORKSPACE} uses: actions/checkout@v3 + with: + ref: ${{ inputs.TRIAL_BRANCH }} - name: Login to GitHub Container Registry uses: docker/login-action@v2 @@ -158,12 +135,6 @@ jobs: BASE_IMAGE=${{ inputs.BASE_IMAGE }} BAZEL_CACHE=${{ vars.BAZEL_REMOTE_CACHE_URL }} BUILD_DATE=${{ inputs.BUILD_DATE }} - REPO_JAX=${{ inputs.REPO_JAX }} - REPO_XLA=${{ inputs.REPO_XLA }} - REPO_TE=${{ inputs.REPO_TE }} - REF_JAX=${{ inputs.REF_JAX }} - REF_XLA=${{ inputs.REF_XLA }} - REF_TE=${{ inputs.REF_TE }} GIT_USER_NAME=${{ inputs.GIT_USER_NAME }} GIT_USER_EMAIL=${{ inputs.GIT_USER_EMAIL }} @@ -197,10 +168,6 @@ jobs: BASE_IMAGE=${{ inputs.BASE_IMAGE }} BAZEL_CACHE=${{ vars.BAZEL_REMOTE_CACHE_URL }} BUILD_DATE=${{ inputs.BUILD_DATE }} - REPO_JAX=${{ inputs.REPO_JAX }} - REPO_XLA=${{ inputs.REPO_XLA }} - REF_JAX=${{ inputs.REF_JAX }} - REF_XLA=${{ inputs.REF_XLA }} - name: Generate sitrep if: success() || failure() diff --git a/.github/workflows/_build_pax.yaml b/.github/workflows/_build_pax.yaml index de2b3dafd..5654d76c2 100644 --- a/.github/workflows/_build_pax.yaml +++ b/.github/workflows/_build_pax.yaml @@ -17,26 +17,6 @@ on: description: "Build date in YYYY-MM-DD format" required: false default: 'NOT SPECIFIED' - REPO_PAXML: - type: string - description: URL of Paxml repository to check out - required: false - default: "https://github.com/google/paxml.git" - REPO_PRAXIS: - type: string - description: URL of Praxis repository to check out - required: false - default: "https://github.com/google/praxis.git" - REF_PAXML: - type: string - description: Git commit, tag, or branch for Paxml - required: false - default: main - REF_PRAXIS: - type: string - description: Git commit, tag, or branch for Praxis - required: false - default: main ARTIFACT_NAME: type: string description: 'Name of the artifact zip file' @@ -119,10 +99,6 @@ jobs: build-args: | BASE_IMAGE=${{ inputs.BASE_IMAGE }} BUILD_DATE=${{ inputs.BUILD_DATE }} - REPO_PAXML=${{ inputs.REPO_PAXML }} - REPO_PRAXIS=${{ inputs.REPO_PRAXIS }} - REF_PAXML=${{ inputs.REF_PAXML }} - REF_PRAXIS=${{ inputs.REF_PRAXIS }} - name: Set docker metadata - final id: final-metadata @@ -151,10 +127,6 @@ jobs: build-args: | BASE_IMAGE=${{ inputs.BASE_IMAGE }} BUILD_DATE=${{ inputs.BUILD_DATE }} - REPO_PAXML=${{ inputs.REPO_PAXML }} - REPO_PRAXIS=${{ inputs.REPO_PRAXIS }} - REF_PAXML=${{ inputs.REF_PAXML }} - REF_PRAXIS=${{ inputs.REF_PRAXIS }} - name: Generate sitrep if: success() || failure() diff --git a/.github/workflows/_build_rosetta.yaml b/.github/workflows/_build_rosetta.yaml index ce0cda580..d2d8d4df1 100644 --- a/.github/workflows/_build_rosetta.yaml +++ b/.github/workflows/_build_rosetta.yaml @@ -100,7 +100,9 @@ jobs: id: mealkit-build uses: docker/build-push-action@v4 with: - context: rosetta/ + context: . + build-contexts: | + jax-toolbox=. push: true file: rosetta/Dockerfile.${{ inputs.BASE_LIBRARY }} platforms: linux/${{ inputs.ARCHITECTURE }} @@ -127,7 +129,9 @@ jobs: id: final-build uses: docker/build-push-action@v4 with: - context: rosetta/ + context: . + build-contexts: | + jax-toolbox=. push: true file: rosetta/Dockerfile.${{ inputs.BASE_LIBRARY }} platforms: linux/${{ inputs.ARCHITECTURE }} diff --git a/.github/workflows/_build_t5x.yaml b/.github/workflows/_build_t5x.yaml index 890a47141..fac2dfbfe 100644 --- a/.github/workflows/_build_t5x.yaml +++ b/.github/workflows/_build_t5x.yaml @@ -17,26 +17,6 @@ on: description: "Build date in YYYY-MM-DD format" required: false default: 'NOT SPECIFIED' - REPO_T5X: - type: string - description: URL of T5X repository to check out - required: false - default: "https://github.com/google-research/t5x.git" - REF_T5X: - type: string - description: Git commit, tag, or branch for T5X - required: false - default: main - REPO_TE: - type: string - description: URL of TE repository to check out - required: false - default: "https://github.com/NVIDIA/TransformerEngine.git" - REF_TE: - type: string - description: Git commit, tag, or branch for TE - required: false - default: main ARTIFACT_NAME: type: string description: 'Name of the artifact zip file' @@ -119,10 +99,6 @@ jobs: build-args: | BASE_IMAGE=${{ inputs.BASE_IMAGE }} BUILD_DATE=${{ inputs.BUILD_DATE }} - REPO_T5X=${{ inputs.REPO_T5X }} - REF_T5X=${{ inputs.REF_T5X }} - REPO_TE=${{ inputs.REPO_TE }} - REF_TE=${{ inputs.REF_TE }} - name: Set docker metadata - final id: final-metadata @@ -151,10 +127,6 @@ jobs: build-args: | BASE_IMAGE=${{ inputs.BASE_IMAGE }} BUILD_DATE=${{ inputs.BUILD_DATE }} - REPO_T5X=${{ inputs.REPO_T5X }} - REF_T5X=${{ inputs.REF_T5X }} - REPO_TE=${{ inputs.REPO_TE }} - REF_TE=${{ inputs.REF_TE }} - name: Generate sitrep if: success() || failure() diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 9ff53847c..fb8919221 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -7,25 +7,7 @@ on: ARCHITECTURE: type: string required: true - CUDA_IMAGE: - type: string - required: true - SRC_JAX: - type: string - required: true - SRC_XLA: - type: string - required: true - SRC_TE: - type: string - required: true - SRC_T5X: - type: string - required: true - SRC_PAXML: - type: string - required: true - SRC_PRAXIS: + TRIAL_BRANCH: type: string required: true outputs: @@ -53,18 +35,6 @@ jobs: runs-on: ubuntu-22.04 outputs: BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }} - REPO_JAX: ${{ steps.parse-inputs.outputs.REPO_JAX }} - REF_JAX: ${{ steps.parse-inputs.outputs.REF_JAX }} - REPO_XLA: ${{ steps.parse-inputs.outputs.REPO_XLA }} - REF_XLA: ${{ steps.parse-inputs.outputs.REF_XLA }} - REPO_TE: ${{ steps.parse-inputs.outputs.REPO_TE }} - REF_TE: ${{ steps.parse-inputs.outputs.REF_TE }} - REPO_T5X: ${{ steps.parse-inputs.outputs.REPO_T5X }} - REF_T5X: ${{ steps.parse-inputs.outputs.REF_T5X }} - REPO_PAXML: ${{ steps.parse-inputs.outputs.REPO_PAXML }} - REF_PAXML: ${{ steps.parse-inputs.outputs.REF_PAXML }} - REPO_PRAXIS: ${{ steps.parse-inputs.outputs.REPO_PRAXIS }} - REF_PRAXIS: ${{ steps.parse-inputs.outputs.REF_PRAXIS }} steps: - name: Check out the repository under ${GITHUB_WORKSPACE} uses: actions/checkout@v3 @@ -76,26 +46,11 @@ jobs: BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d') echo "BUILD_DATE=${BUILD_DATE}" >> $GITHUB_OUTPUT - - name: split input "repo#ref" into repo and ref parts - id: parse-inputs - shell: bash -x -e {0} - run: | - source .github/workflows/scripts/parse_git_src.sh - - # default values are for `pull_request` event types - parse_git_src JAX "${{ inputs.SRC_JAX }}" - parse_git_src XLA "${{ inputs.SRC_XLA }}" - parse_git_src TE "${{ inputs.SRC_TE }}" - parse_git_src T5X "${{ inputs.SRC_T5X }}" - parse_git_src PAXML "${{ inputs.SRC_PAXML }}" - parse_git_src PRAXIS "${{ inputs.SRC_PRAXIS }}" - build-base: needs: metadata uses: ./.github/workflows/_build_base.yaml with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} - BASE_IMAGE: ${{ inputs.CUDA_IMAGE || 'latest' }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} secrets: inherit @@ -106,12 +61,7 @@ jobs: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAG }} - REPO_JAX: ${{ needs.metadata.outputs.REPO_JAX }} - REF_JAX: ${{ needs.metadata.outputs.REF_JAX }} - REPO_XLA: ${{ needs.metadata.outputs.REPO_XLA }} - REF_XLA: ${{ needs.metadata.outputs.REF_XLA }} - REPO_TE: ${{ needs.metadata.outputs.REPO_TE }} - REF_TE: ${{ needs.metadata.outputs.REF_TE }} + TRIAL_BRANCH: ${{ inputs.TRIAL_BRANCH }} secrets: inherit build-t5x: @@ -122,8 +72,6 @@ jobs: ARCHITECTURE: amd64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }} - REPO_T5X: ${{ needs.metadata.outputs.REPO_T5X }} - REF_T5X: ${{ needs.metadata.outputs.REF_T5X }} secrets: inherit build-pax: @@ -133,10 +81,6 @@ jobs: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }} - REPO_PAXML: ${{ needs.metadata.outputs.REPO_PAXML }} - REF_PAXML: ${{ needs.metadata.outputs.REF_PAXML }} - REPO_PRAXIS: ${{ needs.metadata.outputs.REPO_PRAXIS }} - REF_PRAXIS: ${{ needs.metadata.outputs.REF_PRAXIS }} secrets: inherit build-rosetta-t5x: diff --git a/.github/workflows/_sandbox.yaml b/.github/workflows/_sandbox.yaml index cc2adc056..8d1070c9c 100644 --- a/.github/workflows/_sandbox.yaml +++ b/.github/workflows/_sandbox.yaml @@ -1,17 +1,28 @@ name: "~Sandbox" on: - # workflow_dispatch: - # push: + workflow_dispatch: + workflow_run: + workflows: + - Nightly Dsitribution test + - Nightly JAX unit test + - Nightly Transformer Engine test + - Nightly Pax MGMN performance test # The trial branch is propagated thru this workflow + - Nightly T5X MGMN performance test + - Nightly Rosetta Paxml build and test + - Nightly Rosetta T5x build and test + types: [completed] + branches: [main] permissions: - contents: read # to fetch code + contents: write # to fetch code, and create commits actions: write # to cancel previous workflows packages: write # to upload container jobs: sandbox: runs-on: ubuntu-22.04 + if: always() steps: - name: Login to GitHub Container Registry uses: docker/login-action@v2 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0949d5b92..ef4371f81 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,43 +6,11 @@ on: - '**.md' workflow_dispatch: inputs: - CUDA_IMAGE: + TRIAL_BRANCH: type: string - description: 'Base CUDA image, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04' required: false - default: 'latest' - SRC_JAX: - description: 'JAX source: #' - type: string - required: true - default: 'https://github.com/google/jax.git#main' - SRC_XLA: - description: 'XLA source: #' - type: string - required: true - default: 'https://github.com/openxla/xla.git#main' - SRC_TE: - description: 'TE source: #' - type: string - required: true - # TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch - # This should be reverted to main ASAP - default: 'https://github.com/NVIDIA/TransformerEngine.git#main' - SRC_T5X: - description: 'T5X source: #' - type: string - required: true - default: 'https://github.com/google-research/t5x.git#main' - SRC_PAXML: - description: 'Paxml source: #' - type: string - required: true - default: 'https://github.com/google/paxml.git#main' - SRC_PRAXIS: - description: 'Praxis source: #' - type: string - required: true - default: 'https://github.com/google/praxis.git#main' + default: '' + description: 'Branch that contains manifest and patches. Default is empty which uses GITHUB_SHA' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -55,30 +23,35 @@ permissions: jobs: + metadata: + runs-on: ubuntu-22.04 + outputs: + TRIAL_BRANCH: ${{ steps.meta.outputs.TRIAL_BRANCH }} + steps: + - name: Set optional trial branch + id: meta + shell: bash -x -e {0} + run: | + if [[ -z "${{ inputs. TRIAL_BRANCH}}" ]]; then + echo "TRIAL_BRANCH=${{ github.sha }}" + else + echo "TRIAL_BRANCH=${TRIAL_BRANCH}" + fi | tee $GITHUB_OUTPUT + amd64: + needs: metadata uses: ./.github/workflows/_ci.yaml with: ARCHITECTURE: amd64 - CUDA_IMAGE: ${{ inputs.CUDA_IMAGE || 'latest' }} - SRC_JAX: ${{ inputs.SRC_JAX || 'https://github.com/google/jax.git#main' }} - SRC_XLA: ${{ inputs.SRC_XLA || 'https://github.com/openxla/xla.git#main'}} - SRC_TE: ${{ inputs.SRC_TE || 'https://github.com/NVIDIA/TransformerEngine.git#main'}} - SRC_T5X: ${{ inputs.SRC_T5X || 'https://github.com/google-research/t5x.git#main'}} - SRC_PAXML: ${{ inputs.SRC_PAXML || 'https://github.com/google/paxml.git#main'}} - SRC_PRAXIS: ${{ inputs.SRC_PRAXIS || 'https://github.com/google/praxis.git#main'}} + TRIAL_BRANCH: ${{ needs.metadata.outputs.TRIAL_BRANCH }} secrets: inherit arm64: + needs: metadata uses: ./.github/workflows/_ci.yaml with: ARCHITECTURE: arm64 - CUDA_IMAGE: ${{ inputs.CUDA_IMAGE || 'latest' }} - SRC_JAX: ${{ inputs.SRC_JAX || 'https://github.com/google/jax.git#main' }} - SRC_XLA: ${{ inputs.SRC_XLA || 'https://github.com/openxla/xla.git#main'}} - SRC_TE: ${{ inputs.SRC_TE || 'https://github.com/NVIDIA/TransformerEngine.git#main'}} - SRC_T5X: ${{ inputs.SRC_T5X || 'https://github.com/google-research/t5x.git#main'}} - SRC_PAXML: ${{ inputs.SRC_PAXML || 'https://github.com/google/paxml.git#main'}} - SRC_PRAXIS: ${{ inputs.SRC_PRAXIS || 'https://github.com/google/praxis.git#main'}} + TRIAL_BRANCH: ${{ needs.metadata.outputs.TRIAL_BRANCH }} secrets: inherit finalize: diff --git a/.github/workflows/nightly-jax-build.yaml b/.github/workflows/nightly-jax-build.yaml index ee94b86b9..3f625cd29 100644 --- a/.github/workflows/nightly-jax-build.yaml +++ b/.github/workflows/nightly-jax-build.yaml @@ -13,7 +13,7 @@ on: required: false permissions: - contents: read # to fetch code + contents: write # to fetch code, and create commits actions: write # to cancel previous workflows packages: write # to upload container @@ -38,20 +38,62 @@ jobs: run: | echo "PUBLISH=${{ github.event_name == 'schedule' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT - amd64: + bump-world-state: needs: metadata + outputs: + TRIAL_BRANCH: ${{ steps.trial-meta.outputs.TRIAL_BRANCH }} + runs-on: ubuntu-22.04 + steps: + - name: Check out the repository under ${GITHUB_WORKSPACE} + uses: actions/checkout@v3 + - name: Update manifest and patches in-place - show diff + working-directory: .github/container + shell: bash -x -e {0} + run: | + bash bump.sh --input-manifest manifest.yaml + git diff + - name: Push trial branch + id: trial-meta + if: needs.metadata.outputs.PUBLISH == 'true' + shell: bash -x -e {0} + run: | + git config user.name "JAX-Toolbox CI" + git config user.email "jax@nvidia.com" + # Prepend trial branch with "z" to make it appear at the end + trial_branch=znightly-${{ github.run_id }}-${{ needs.metadata.outputs.BUILD_DATE }} + git switch -c $trial_branch + git status + git add -u + git add .github/container/patches/ + git status + git commit -m "generated ${{github.run_id }}" + git push --set-upstream origin $trial_branch + + echo "$trial_branch" > ./trial-branch.txt + echo TRIAL_BRANCH=$trial_branch | tee $GITHUB_OUTPUT + - name: 'Upload trial branch artifact' + if: needs.metadata.outputs.PUBLISH == 'true' + uses: actions/upload-artifact@v3 + with: + name: trial-branch + path: trial-branch.txt + + amd64: + needs: [metadata, bump-world-state] uses: ./.github/workflows/_build_jax.yaml with: ARCHITECTURE: amd64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} + TRIAL_BRANCH: ${{ needs.bump-world-state.outputs.TRIAL_BRANCH }} secrets: inherit arm64: - needs: metadata + needs: [metadata, bump-world-state] uses: ./.github/workflows/_build_jax.yaml with: ARCHITECTURE: arm64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} + TRIAL_BRANCH: ${{ needs.bump-world-state.outputs.TRIAL_BRANCH }} secrets: inherit publish-mealkit: diff --git a/rosetta/Dockerfile.pax b/rosetta/Dockerfile.pax index 0ba10ae75..b78d2b5c2 100644 --- a/rosetta/Dockerfile.pax +++ b/rosetta/Dockerfile.pax @@ -2,56 +2,70 @@ ARG BASE_IMAGE=ghcr.io/nvidia/upstream-pax:mealkit ARG GIT_USER_EMAIL=jax@nvidia.com ARG GIT_USER_NAME=NVIDIA -ARG SRC_PATH_PAXML=/opt/paxml -ARG SRC_PATH_PRAXIS=/opt/praxis -# These patchlist paths should be relative to this script -ARG PAXML_PATCHLIST=patchlist-paxml.txt -ARG PRAXIS_PATCHLIST=patchlist-praxis.txt +# If set to "true", then will pull new local patches, the manifest.yaml and create-distribution.sh (in case it was updated). +# This is useful for development if you run `./bump.sh -i manifest.yaml` manually and do not want to trigger a full rebuild all +# the way up to the jax build. +ARG UPDATE_PATCHES=false -FROM scratch as rosetta-source -ARG SRC_PATH_PAXML -ARG SRC_PATH_PRAXIS - -COPY . / - -FROM scratch as pax-mirror-source -ADD --keep-git-dir=true https://github.com/google/paxml.git#main / - -FROM scratch as praxis-mirror-source -ADD --keep-git-dir=true https://github.com/google/praxis.git#main / +# Rosetta and optionally patches are pulled from this +FROM scratch AS jax-toolbox ############################################################################### ### Download source and add auxiliary scripts ################################################################################ FROM ${BASE_IMAGE} AS mealkit -ENV ENABLE_TE=1 - ARG GIT_USER_EMAIL ARG GIT_USER_NAME -ARG PAXML_PATCHLIST -ARG PRAXIS_PATCHLIST +ARG UPDATE_PATCHES -COPY --from=rosetta-source / /opt/rosetta -WORKDIR /opt/rosetta -RUN --mount=target=/opt/pax-mirror,from=pax-mirror-source,readwrite \ - --mount=target=/opt/praxis-mirror,from=praxis-mirror-source,readwrite <<"EOF" bash -e +ENV ENABLE_TE=1 + +RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu +############DELETE git config --global user.email "${GIT_USER_EMAIL}" git config --global user.name "${GIT_USER_NAME}" -bash create-distribution.sh \ - -p ${PAXML_PATCHLIST} \ - -m https://github.com/nvjax-svc-0/paxml.git \ - -d /opt/paxml \ - -e /opt/pax-mirror -bash create-distribution.sh \ - -p ${PRAXIS_PATCHLIST} \ - -m https://github.com/nvjax-svc-0/praxis.git \ - -d /opt/praxis \ - -e /opt/praxis-mirror +git clone -b 23.3.1 https://github.com/pypa/pip.git /opt/pip +cp /mnt/jax-toolbox/.github/container/pip-vcs-equivalency.patch /opt/pip/ +cd /opt/pip +git apply > /opt/pip-tools.d/manifest.t5x -echo "-e file:///opt/rosetta" >> /opt/pip-tools.d/manifest.t5x +bash ${MANIFEST_DIR}/create-distribution.sh \ + --manifest ${MANIFEST_FILE} \ + --package t5x +bash ${MANIFEST_DIR}/create-distribution.sh \ + --manifest ${MANIFEST_FILE} \ + --package flax +# Remove .gitconfig to avoid end-user authoring commits as the "build user" +rm -f ~/.gitconfig + +echo "--extra-index-url https://developer.download.nvidia.com/compute/redist" >> /opt/pip-tools.d/requirements-t5x.in +echo "-e file:///opt/rosetta" >> /opt/pip-tools.d/requirements-t5x.in EOF WORKDIR /opt/rosetta +COPY --from=jax-toolbox rosetta/tests/test-vit.sh /usr/local/bin ############################################################################### ### Install accumulated packages from the base image and the previous stage diff --git a/rosetta/README.md b/rosetta/README.md index c91a4ed3f..dfd01c750 100644 --- a/rosetta/README.md +++ b/rosetta/README.md @@ -6,22 +6,25 @@ LLM, CV, and multimodal models. ### Building rosetta with a specific base ```bash -ROSETTA_BASE=t5x # or pax +cd JAX-Toolbox -docker buildx build --tag rosetta:latest -f Dockerfile.${ROSETTA_BASE} . +ROSETTA_BASE=pax # or t5x + +docker buildx build --build-context jax-toolbox=. --tag rosetta-${ROSETTA_BASE}:latest -f rosetta/Dockerfile.${ROSETTA_BASE} . # If you want to specify a specific base image -docker buildx build --tag rosetta:latest -f Dockerfile.${ROSETTA_BASE} --build-arg BASE_IMAGE=ghcr.io/nvidia/${ROSETTA_BASE}:mealkit-YYYY-MM-DD . +docker buildx build --build-context jax-toolbox=. --tag rosetta-${ROSETTA_BASE}:latest -f rosetta/Dockerfile.${ROSETTA_BASE} --build-arg BASE_IMAGE=ghcr.io/nvidia/upstream-${ROSETTA_BASE}:mealkit-YYYY-MM-DD . ``` ### Advanced use-cases ```sh -# [T5x Example] If you want to build with a different patchlist (patchlist must be relative to rosetta dir) -docker buildx build --build-arg T5X_PATCHLIST=patches/t5x/patchlist-t5x.txt.gen --build-arg FLAX_PATCHLIST=patches/flax/patchlist-flax.txt.gen --target rosetta --tag rosetta:latest -f Dockerfile.t5x . +# If you want to build with updated patches +cd JAX-Toolbox + +ROSETTA_BASE=pax -# [T5x Example] If you want to build with patches from another image -scripts/extract-patches.sh # Extracts generated patch dir under ./patches/ -docker buildx build --build-arg T5X_PATCHLIST=patches/t5x/patchlist-t5x.txt.gen --build-arg FLAX_PATCHLIST=patches/flax/patchlist-flax.txt.gen --target rosetta --tag rosetta:latest -f Dockerfile.t5x . +bash .github/container/bump.sh -i .github/container/manifest.yaml +docker buildx build --build-context jax-toolbox=. --tag rosetta-${ROSETTA_BASE}:latest -f rosetta/Dockerfile.${ROSETTA_BASE} --build-arg UPDATE_PATCHES=true . ``` ## Development diff --git a/rosetta/patchlist-flax.txt b/rosetta/patchlist-flax.txt deleted file mode 100644 index 335d9f0f3..000000000 --- a/rosetta/patchlist-flax.txt +++ /dev/null @@ -1,8 +0,0 @@ -############## -# Patch list # -############## -# - Internal patches (These are branches that start with "patch/") -# - External Pull Requests (These are pull requests with upstream flax and are of the form "pull/$PULLID/head") -# - Note: Only the first column is used as a git-ref, so anything after is a comment - -pull/3340/head # Add Sharding Annotations to Flax Modules diff --git a/rosetta/patchlist-paxml.txt b/rosetta/patchlist-paxml.txt deleted file mode 100644 index a7a67f2ab..000000000 --- a/rosetta/patchlist-paxml.txt +++ /dev/null @@ -1,8 +0,0 @@ -############## -# Patch list # -############## -# - Internal patches (These are branches that start with "patch/") -# - External Pull Requests (These are pull requests with upstream paxml and are of the form "pull/$PULLID/head") -# - Note: Only the first column is used as a git-ref, so anything after is a comment - -pull/46/head # adds Transformer Engine support diff --git a/rosetta/patchlist-praxis.txt b/rosetta/patchlist-praxis.txt deleted file mode 100644 index 104103a33..000000000 --- a/rosetta/patchlist-praxis.txt +++ /dev/null @@ -1,8 +0,0 @@ -############## -# Patch list # -############## -# - Internal patches (These are branches that start with "patch/") -# - External Pull Requests (These are pull requests with upstream praxis and are of the form "pull/$PULLID/head") -# - Note: Only the first column is used as a git-ref, so anything after is a comment -pull/27/head # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. -pull/36/head # adds Transformer Engine support diff --git a/rosetta/patchlist-t5x.txt b/rosetta/patchlist-t5x.txt deleted file mode 100644 index 11285c6df..000000000 --- a/rosetta/patchlist-t5x.txt +++ /dev/null @@ -1,10 +0,0 @@ -############## -# Patch list # -############## -# - Internal patches (These are branches that start with "patch/") -# - External Pull Requests (These are pull requests with upstream t5x and are of the form "pull/$PULLID/head") -# - Note: Only the first column is used as a git-ref, so anything after is a comment - -mirror/patch/partial-checkpoint-restore # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore -mirror/patch/dali-support # pull/1393/head # https://github.com/google-research/t5x/pull/1393: Adds DALI support to t5x -mirror/patch/t5x_te_in_contrib_noindent # pull/1391/head # https://github.com/google-research/t5x/pull/1391: Adds transformer engine support and GPU optimizations to T5x (enables H100) \ No newline at end of file diff --git a/rosetta/tests/extra-only-distribution.sh b/rosetta/tests/extra-only-distribution.sh index 5c21b64b3..73257042d 100755 --- a/rosetta/tests/extra-only-distribution.sh +++ b/rosetta/tests/extra-only-distribution.sh @@ -6,31 +6,50 @@ cd $SCRIPT_DIR set -eou pipefail # Version should work on Linux || Darwin -#repo_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'mytmpdir') -repo_tmp=/tmp/t5x -extra_tmp=/tmp/extra -patchlist_tmp=$(mktemp /tmp/patchlist.txt.XXXXXX) +repo_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'upstream') +extra_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'extra') +workspace_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'workspace') +manifest_tmp=$(mktemp /tmp/manifest.yaml.XXXXXX) +LIBRARY=t5x UPSTREAM_URL=https://github.com/google-research/t5x.git # Commit was taken just before PR-1372 DISTRIBUTION_BASE_REF=22117ce5a3606706ba9519ccdd77b532ad8ff7b2 +EXTRA_PATCH_BRANCH=patch/delete-readme git clone $UPSTREAM_URL $repo_tmp git clone $UPSTREAM_URL $extra_tmp +git -C $repo_tmp checkout $DISTRIBUTION_BASE_REF git -C $extra_tmp checkout $DISTRIBUTION_BASE_REF -echo "patch/delete-readme" >> $patchlist_tmp cd $extra_tmp -git switch -c patch/delete-readme +git switch -c $EXTRA_PATCH_BRANCH git rm README.md git commit -m 'TEST DELETE README' cd - -bash ../create-distribution.sh \ - -r $DISTRIBUTION_BASE_REF \ - -d $repo_tmp \ - -e $extra_tmp \ - -p $patchlist_tmp +cat <> $manifest_tmp +t5x: + url: https://github.com/google-research/t5x.git + mirror_url: https://github.com/nvjax-svc-0/t5x.git + extra_dir: $extra_tmp + tracking_ref: main + ref: $DISTRIBUTION_BASE_REF + mode: git-clone + patches: + $EXTRA_PATCH_BRANCH: null +EOF + +cp ../../.github/container/create-distribution.sh $workspace_tmp/ +base_cmd() { + bash $workspace_tmp/create-distribution.sh \ + --manifest $manifest_tmp \ + --override_dir $repo_tmp \ + --package $LIBRARY \ + $@ +} +base_cmd --skip-apply +base_cmd # TESTS EXPECTED_HEAD_COMMIT_MSG="*TEST DELETE README" @@ -50,5 +69,5 @@ elif [[ "$PENULTIMATE_COMMIT_MSG" == "$EXPECTED_PENULTIMATE_COMMIT_MSG" ]]; then exit 1 fi -rm -rf $repo_tmp $patchlist_tmp $extra_tmp +rm -rf $repo_tmp $manifest_tmp $workspace_tmp $extra_tmp echo "TEST SUCCESS" diff --git a/rosetta/tests/mirror-only-distribution.sh b/rosetta/tests/mirror-only-distribution.sh index 73331802c..1726dab61 100755 --- a/rosetta/tests/mirror-only-distribution.sh +++ b/rosetta/tests/mirror-only-distribution.sh @@ -6,23 +6,38 @@ cd $SCRIPT_DIR set -eou pipefail # Version should work on Linux || Darwin -#repo_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'mytmpdir') -repo_tmp=/tmp/t5x -patchlist_tmp=$(mktemp /tmp/patchlist.txt.XXXXXX) +tmp_base=$(mktemp -d 2>/dev/null || mktemp -d -t 'upstream') +workspace_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'workspace') +manifest_tmp=$(mktemp /tmp/manifest.yaml.XXXXXX) -UPSTREAM_URL=https://github.com/google-research/t5x.git -MIRROR_URL=https://github.com/nvjax-svc-0/t5x.git +LIBRARY=t5x # Commit was taken just before PR-1372 DISTRIBUTION_BASE_REF=22117ce5a3606706ba9519ccdd77b532ad8ff7b2 - -git clone $UPSTREAM_URL $repo_tmp -echo "mirror/pull/4/head" >> $patchlist_tmp - -bash ../create-distribution.sh \ - -r $DISTRIBUTION_BASE_REF \ - -d $repo_tmp \ - -m $MIRROR_URL \ - -p $patchlist_tmp +repo_tmp=$tmp_base/$LIBRARY + +cat <> $manifest_tmp +t5x: + url: https://github.com/google-research/t5x.git + mirror_url: https://github.com/nvjax-svc-0/t5x.git + tracking_ref: main + ref: $DISTRIBUTION_BASE_REF + mode: git-clone + patches: + mirror/pull/4/head: null +EOF + +bash ../../.github/container/get-source.sh --base-dir $tmp_base --manifest $manifest_tmp --library $LIBRARY + +cp ../../.github/container/create-distribution.sh $workspace_tmp/ +base_cmd() { + bash $workspace_tmp/create-distribution.sh \ + --manifest $manifest_tmp \ + --override_dir $repo_tmp \ + --package $LIBRARY \ + $@ +} +base_cmd --skip-apply +base_cmd # TESTS EXPECTED_HEAD_COMMIT_MSG="*TEST - DELETE README" @@ -42,5 +57,5 @@ elif [[ "$PENULTIMATE_COMMIT_MSG" == "$EXPECTED_PENULTIMATE_COMMIT_MSG" ]]; then exit 1 fi -rm -rf $repo_tmp $patchlist_tmp +rm -rf $repo_tmp $manifest_tmp $workspace_tmp echo "TEST SUCCESS" diff --git a/rosetta/test-vit.sh b/rosetta/tests/test-vit.sh similarity index 100% rename from rosetta/test-vit.sh rename to rosetta/tests/test-vit.sh diff --git a/rosetta/tests/upstream-only-distribution.sh b/rosetta/tests/upstream-only-distribution.sh index 7a3c0f45d..a314534b6 100755 --- a/rosetta/tests/upstream-only-distribution.sh +++ b/rosetta/tests/upstream-only-distribution.sh @@ -6,20 +6,37 @@ cd $SCRIPT_DIR set -eou pipefail # Version should work on Linux || Darwin -repo_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'mytmpdir') -patchlist_tmp=$(mktemp /tmp/patchlist.txt.XXXXXX) +tmp_base=$(mktemp -d 2>/dev/null || mktemp -d -t 'upstream') +workspace_tmp=$(mktemp -d 2>/dev/null || mktemp -d -t 'workspace') +manifest_tmp=$(mktemp /tmp/manifest.yaml.XXXXXX) -UPSTREAM_URL=https://github.com/google-research/t5x.git +LIBRARY=t5x # Commit was taken just before PR-1372 DISTRIBUTION_BASE_REF=22117ce5a3606706ba9519ccdd77b532ad8ff7b2 - -git clone $UPSTREAM_URL $repo_tmp -echo "pull/1372/head" >> $patchlist_tmp - -bash ../create-distribution.sh \ - -r $DISTRIBUTION_BASE_REF \ - -d $repo_tmp \ - -p $patchlist_tmp +repo_tmp=$tmp_base/$LIBRARY + +cat <> $manifest_tmp +t5x: + url: https://github.com/google-research/t5x.git + mirror_url: https://github.com/nvjax-svc-0/t5x.git + tracking_ref: main + ref: $DISTRIBUTION_BASE_REF + mode: git-clone + patches: + pull/1372/head: null +EOF +bash ../../.github/container/get-source.sh --base-dir $tmp_base --manifest $manifest_tmp --library $LIBRARY + +cp ../../.github/container/create-distribution.sh $workspace_tmp/ +base_cmd() { + bash $workspace_tmp/create-distribution.sh \ + --manifest $manifest_tmp \ + --override_dir $repo_tmp \ + --package $LIBRARY \ + $@ +} +base_cmd --skip-apply +base_cmd # TESTS EXPECTED_HEAD_COMMIT_MSG=*"Support batched indices in PositionEmbed. This is useful to support prefilling caches for prompted decoding with batches containing prompts of different lengths." @@ -39,5 +56,5 @@ elif [[ "$PENULTIMATE_COMMIT_MSG" == "$EXPECTED_PENULTIMATE_COMMIT_MSG" ]]; then exit 1 fi -rm -rf $repo_tmp $patchlist_tmp +rm -rf $repo_tmp $manifest_tmp $workspace_tmp echo "TEST SUCCESS" From 4951c31a8dfdedca798c8300897251d79cc0ddfe Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 1 Dec 2023 12:31:44 -0800 Subject: [PATCH 05/29] fix pip patch --- .github/container/pip-vcs-equivalency.patch | 35 +++++++++++++++------ 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/.github/container/pip-vcs-equivalency.patch b/.github/container/pip-vcs-equivalency.patch index 724ed76b2..947859e21 100644 --- a/.github/container/pip-vcs-equivalency.patch +++ b/.github/container/pip-vcs-equivalency.patch @@ -1,5 +1,5 @@ diff --git a/src/pip/_internal/resolution/resolvelib/requirements.py b/src/pip/_internal/resolution/resolvelib/requirements.py -index 7d1e7bfdd..d67042896 100644 +index 7d1e7bfdd..5b954cfb5 100644 --- a/src/pip/_internal/resolution/resolvelib/requirements.py +++ b/src/pip/_internal/resolution/resolvelib/requirements.py @@ -5,6 +5,7 @@ from pip._internal.req.constructors import install_req_drop_extras @@ -10,13 +10,20 @@ index 7d1e7bfdd..d67042896 100644 class ExplicitRequirement(Requirement): -@@ -37,6 +38,15 @@ class ExplicitRequirement(Requirement): +@@ -37,6 +38,22 @@ class ExplicitRequirement(Requirement): return self.candidate, None def is_satisfied_by(self, candidate: Candidate) -> bool: + # TODO: Currently only checks if the library is the same. + # It is more robust to check the other properties of requirement -+ # but this is sufficient for us. ++ # but this is sufficient for us. Future checks may include: ++ # - check git-ref/SHA from candidate.source_link.(path||show_url) and return False if conflicting SHA/git-ref: ++ # - ex[path]: '/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f' vs. '/google/fiddle' ++ # - ex[show_url]: 'fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f' vs. 'fiddle' ++ # - check VCS provider from candidate.source_link.netloc: ++ # - ex: 'github.com' ++ # - check VCS protoocol from candidate.source_link.scheme: ++ # - ex: 'git+https' + if os.environ.get('JAX_TOOLBOX_VCS_EQUIVALENCY', False) and \ + self.candidate.name == candidate.name and \ + self.candidate.source_link.is_vcs and \ @@ -27,7 +34,7 @@ index 7d1e7bfdd..d67042896 100644 diff --git a/src/pip/_vendor/resolvelib/resolvers.py b/src/pip/_vendor/resolvelib/resolvers.py -index 2c3d0e306..1a87b3289 100644 +index 2c3d0e306..5a8e263a5 100644 --- a/src/pip/_vendor/resolvelib/resolvers.py +++ b/src/pip/_vendor/resolvelib/resolvers.py @@ -5,6 +5,8 @@ import operator @@ -39,11 +46,10 @@ index 2c3d0e306..1a87b3289 100644 RequirementInformation = collections.namedtuple( "RequirementInformation", ["requirement", "parent"] ) -@@ -165,8 +167,25 @@ class Resolution(object): +@@ -165,8 +167,35 @@ class Resolution(object): else: information = [RequirementInformation(requirement, parent)] -+ # Change generator -> concrete list + matches = build_iter_view(matches) + # Force the order of matches (first appears to be used as representative + # for the equivalency class). This puts vcs installs with @ at the top @@ -56,9 +62,20 @@ index 2c3d0e306..1a87b3289 100644 + return 0 + return 1 + -+ # Need to check criterion as well b/c if None then this is first instance -+ # of this identifier and this will error -+ if os.environ.get('JAX_TOOLBOX_VCS_EQUIVALENCY', False) and criterion: ++ # Need to check: ++ # - criterion: if None then this is first instance of this identifier ++ # - requirement.candidate.source_link: only vcs requirements will have this ++ # so if no VCS req present, can skip this logic ++ # - type(matches._sequence) == list if VCS req, otherwise it's an iterator ++ # that we should not concretize into a list since that will force us ++ # to search through the candidates and that may error ++ if os.environ.get('JAX_TOOLBOX_VCS_EQUIVALENCY', False) \ ++ and criterion \ ++ and hasattr(requirement, 'candidate') \ ++ and requirement.candidate.source_link \ ++ and requirement.candidate.source_link.is_vcs \ ++ and isinstance(matches._sequence, list) \ ++ : + matches = build_iter_view(sorted(list(matches), key=vcs_compare)) criterion = Criterion( - candidates=build_iter_view(matches), From 7a852f8f901668bb15ca0359b777eb2972e28e53 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 1 Dec 2023 15:27:22 -0800 Subject: [PATCH 06/29] fix pip-finalize --- .github/container/create-distribution.sh | 5 ++--- .github/container/pip-finalize.sh | 9 +++++---- .github/workflows/nightly-jax-build.yaml | 2 +- rosetta/Dockerfile.pax | 5 ++++- rosetta/Dockerfile.t5x | 5 ++++- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/.github/container/create-distribution.sh b/.github/container/create-distribution.sh index 3fbb8cef6..8345fb337 100755 --- a/.github/container/create-distribution.sh +++ b/.github/container/create-distribution.sh @@ -104,7 +104,6 @@ if [[ -z "$MANIFEST" || -z "$PACKAGE" ]]; then usage 1 fi -BASE_DIR=${BASE_DIR:-/opt} CLEAN_PATCHES=${CLEAN_PATCHES:-0} UPSTREAM_URL=$(yq e ".${PACKAGE}.url" $MANIFEST) # The tracking_ref is interpreted as the default "main" branch and all patches are @@ -115,7 +114,7 @@ MIRROR_GIT_URL=$(yq e ".${PACKAGE}.mirror_url // \"\"" $MANIFEST) EXTRA_DIR=$(yq e ".${PACKAGE}.extra_dir // \"\"" $MANIFEST) SKIP_APPLY=${SKIP_APPLY:-0} -GEN_PATCH_DIR=${GEN_PATCH_DIR:-$SCRIPT_DIR/patches/$PACKAGE} +GEN_PATCH_DIR=$SCRIPT_DIR/patches/$PACKAGE # Associative arrays aren't available before bash <4.0, so maintaining separate key/value arrays PATCH_KEYS=() PATCH_VALUES=() @@ -365,4 +364,4 @@ for remote in ${MIRROR_REMOTE_NAME} ${EXTRA_REMOTE_NAME:-}; do done if [[ -n "${EXTRA_REMOTE_NAME+x}" ]]; then git+extra branch --list "*${TMP_BRANCH_SUFFIX}" | xargs -I@ git -C ${EXTRA_DIR} branch -d @ -fi \ No newline at end of file +fi diff --git a/.github/container/pip-finalize.sh b/.github/container/pip-finalize.sh index d26c48424..34244c164 100755 --- a/.github/container/pip-finalize.sh +++ b/.github/container/pip-finalize.sh @@ -11,7 +11,7 @@ pushd /opt/pip-tools.d pip-compile -o requirements.pre $(ls requirements-*.in) IFS=$'\n' -for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+'); do +for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do # VCS installs are of the form "PACKAGE @ git+..." PACKAGE=$(echo "$line" | awk '{print $1}') ref=$(yq e ".${PACKAGE}.ref" ${MANIFEST_FILE}) @@ -30,13 +30,14 @@ unset IFS # that treats the above as equivalent and prefers the URI wit the SHA JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in) -unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@') -if [[ $(echo "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then +unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true) +if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then echo "Unpinned VCS installs found in $(readlink -f requirements.txt):" echo "$unpinned_vcs_dependencies" exit 1 fi -pip-sync --pip-args '--src /opt' requirements.txt +# --no-deps is required since conflicts can still appear during pip-sync +pip-sync --pip-args '--no-deps --src /opt' requirements.txt rm -rf ~/.cache/* diff --git a/.github/workflows/nightly-jax-build.yaml b/.github/workflows/nightly-jax-build.yaml index 3f625cd29..5f13884f3 100644 --- a/.github/workflows/nightly-jax-build.yaml +++ b/.github/workflows/nightly-jax-build.yaml @@ -60,7 +60,7 @@ jobs: git config user.name "JAX-Toolbox CI" git config user.email "jax@nvidia.com" # Prepend trial branch with "z" to make it appear at the end - trial_branch=znightly-${{ github.run_id }}-${{ needs.metadata.outputs.BUILD_DATE }} + trial_branch=znightly-${{ needs.metadata.outputs.BUILD_DATE }}-${{ github.run_id }} git switch -c $trial_branch git status git add -u diff --git a/rosetta/Dockerfile.pax b/rosetta/Dockerfile.pax index b78d2b5c2..de9e0d71a 100644 --- a/rosetta/Dockerfile.pax +++ b/rosetta/Dockerfile.pax @@ -26,7 +26,10 @@ RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu git config --global user.email "${GIT_USER_EMAIL}" git config --global user.name "${GIT_USER_NAME}" -git clone -b 23.3.1 https://github.com/pypa/pip.git /opt/pip +if [[ ! -d /opt/pip ]]; then + git clone https://github.com/pypa/pip.git /opt/pip +fi +git -C /opt/pip checkout 23.3.1 cp /mnt/jax-toolbox/.github/container/pip-vcs-equivalency.patch /opt/pip/ cd /opt/pip git apply Date: Fri, 1 Dec 2023 17:22:20 -0800 Subject: [PATCH 07/29] add praxis TE patch from upstream --- .github/container/manifest.yaml | 25 +- .github/container/patches/praxis/PR-36.patch | 362 +++++++++++++++++++ 2 files changed, 375 insertions(+), 12 deletions(-) create mode 100644 .github/container/patches/praxis/PR-36.patch diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 6077254ad..20d1071f9 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -1,32 +1,32 @@ -# Updated in: XXX +# Updated in: 2023-12-01 jax: url: https://github.com/google/jax.git tracking_ref: main - ref: b032a0271e3e2ea8d0df64d2f3f1a1e450a38dc9 # 2023-11-15 + ref: 595117b70c11055e569480b80907d8c8a9901805 mode: git-clone xla: url: https://github.com/openxla/xla.git tracking_ref: main - ref: 8fb606ffa03c030035d6c0c9d05534dbf6701906 # 2023-11-15 + ref: 78a5297d8e4301cb3ba2514061f56f89104e3d88 mode: git-clone flax: url: https://github.com/google/flax.git mirror_url: https://github.com/nvjax-svc-0/flax.git tracking_ref: main - ref: a572f6af2fef565c0f9ba2fc12b781e9e3385140 + ref: 230b0d77e98da22b6e574c3cbff743ca1504bfca mode: git-clone patches: pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules transformer-engine: url: https://github.com/NVIDIA/TransformerEngine.git tracking_ref: main - ref: d76118d90df0422d52261adc26a5f4351a1dd71f + ref: 92c1e500dd14608e54f75df8276baa1104c61d48 mode: git-clone t5x: url: https://github.com/google-research/t5x.git mirror_url: https://github.com/nvjax-svc-0/t5x.git tracking_ref: main - ref: c39a33a35bb2f03f6d36455e6378620c6634a995 + ref: 1bfd2f15e5e77b09d60301367f67fdc9bb756b46 mode: git-clone patches: mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore @@ -36,7 +36,7 @@ paxml: url: https://github.com/google/paxml.git mirror_url: https://github.com/nvjax-svc-0/paxml.git tracking_ref: main - ref: 6c811d5e8f82a8aa75530b50223302d98f47e984 + ref: 7ae682d4d99630008e190b96c5296990297175c2 mode: git-clone patches: pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support @@ -44,15 +44,16 @@ praxis: url: https://github.com/google/praxis.git mirror_url: https://github.com/nvjax-svc-0/praxis.git tracking_ref: main - ref: fcadc09773e32a18abd5b0240e07da33316a9636 + ref: b6f32fa0fc6721db1cec75972b0f569c82095956 mode: git-clone patches: pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. + pull/36/head: file://patches/praxis/PR-36.patch # adds Transformer Engine support lingvo: # Used only in ARM pax builds url: https://github.com/tensorflow/lingvo.git tracking_ref: master - ref: 36a1e314864533eeb1cc1e6590e86c10c03b1516 + ref: 0274fa20b4ff194c1c118b94b5f778caa5d9a84a mode: git-clone tensorflow-text: # Used only in ARM pax builds @@ -67,7 +68,7 @@ pydantic: fiddle: url: https://github.com/google/fiddle.git tracking_ref: main - ref: b13db6481720bc897f4efd7e04c7ba4f5907ce74 + ref: d409cf95164599a88e49d2b6a23a0972a7170b0b mode: pip-vcs # Used by t5x airio: @@ -87,8 +88,8 @@ dllogger: mode: pip-vcs jestimator: url: https://github.com/google-research/jestimator.git - tracking_ref: master - ref: fa143d93e337ca8ab77c4510baf21ae52af24ab2 + tracking_ref: main + ref: "fa143d93e337ca8ab77c4510baf21ae52af24ab2" mode: pip-vcs optax: url: https://github.com/deepmind/optax.git diff --git a/.github/container/patches/praxis/PR-36.patch b/.github/container/patches/praxis/PR-36.patch new file mode 100644 index 000000000..d170f4051 --- /dev/null +++ b/.github/container/patches/praxis/PR-36.patch @@ -0,0 +1,362 @@ +From 41488517eb6d95eb7943681e706c8804e6102c41 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 15 Nov 2023 11:38:27 +0800 +Subject: [PATCH 1/2] Adding TE support + +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 176 ++++++++++++++++++++ + praxis/layers/transformers.py | 22 +-- + 2 files changed, 181 insertions(+), 17 deletions(-) + create mode 100644 praxis/contrib/gpu/scripts_gpu/te_helper.py + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +new file mode 100644 +index 0000000..2d5277e +--- /dev/null ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -0,0 +1,176 @@ ++import os ++ ++from praxis import base_layer ++from praxis import pax_fiddle ++from praxis import pytypes ++ ++try: ++ import transformer_engine.jax as te ++ import transformer_engine.jax.flax as te_flax ++ import transformer_engine.jax.praxis as te_praxis ++ _IS_TRANSFORMER_ENGINE_INSTALLED = True ++ import praxis.layers.repeats as praxis_repeat ++ # This is to make Repeat module correctly generate collections we need. ++ praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes ++ te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) ++ ++except ModuleNotFoundError as e: ++ _IS_TRANSFORMER_ENGINE_INSTALLED = False ++ ++ ++LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] ++JTensor = pytypes.JTensor ++ ++ ++class TransformerEngineHelperBase: ++ ++ @staticmethod ++ def get_fprop_caller_of_stack_transformer(fprop, deterministic): ++ raise NotImplementedError ++ ++ @staticmethod ++ def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id): ++ raise NotImplementedError ++ ++ @staticmethod ++ def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): ++ raise NotImplementedError ++ ++ ++ ++class TENotInstalledHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def get_fprop_caller_of_stack_transformer(fprop, deterministic): ++ return fprop ++ ++ @staticmethod ++ def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id): ++ layer_p.name = f'layer_{layer_id}' ++ layer_p.use_cross_attention = stacked_transformer_obj.use_cross_attention ++ layer_p.num_heads = stacked_transformer_obj.num_heads ++ layer_p.dim_per_head = stacked_transformer_obj.dim_per_head ++ layer_p.input_dims = stacked_transformer_obj.model_dims ++ layer_p.packed_input = stacked_transformer_obj.packed_input ++ layer_p.atten_dropout_prob = stacked_transformer_obj.atten_dropout_prob or stacked_transformer_obj.dropout_prob ++ layer_p.residual_dropout_prob = ( ++ stacked_transformer_obj.residual_dropout_prob or stacked_transformer_obj.dropout_prob ++ ) ++ layer_p.relu_dropout_prob = stacked_transformer_obj.relu_dropout_prob or stacked_transformer_obj.dropout_prob ++ layer_p.hidden_dims = stacked_transformer_obj.hidden_dims ++ ++ if stacked_transformer_obj.residual_droppath_prob > 0.0: ++ layer_p.residual_droppath_prob = ( ++ stacked_transformer_obj.residual_droppath_prob * layer_id / max(1, stacked_transformer_obj.num_layers) ++ ) ++ return layer_p ++ ++ @staticmethod ++ def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): ++ return xformer_layer_p.tr_atten_tpl.activation_split_dims_mapping.bld ++ ++ ++class TEInstalledHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def get_fprop_caller_of_stack_transformer(_, deterministic): ++ def _fprop( ++ transformer, ++ x_in, ++ paddings, ++ attention_mask, ++ cross_inputs, ++ cross_attention_mask, ++ segment_pos ++ ): ++ x_out = transformer( ++ inputs=x_in, ++ attention_mask=attention_mask, ++ encoded=cross_inputs, ++ encoder_decoder_mask=cross_attention_mask, ++ deterministic=deterministic) ++ return x_out ++ return _fprop ++ ++ ++ @staticmethod ++ def set_layer_params_to_stack_transformer(stacked_transformer_obj, _, layer_id): ++ te_transformer_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, ++ name=f'layer_{layer_id}', ++ layernorm_type='layernorm', ++ zero_centered_gamma = True, ++ mlp_activations=('gelu',), ++ use_bias=True, ++ self_attn_mask_type='causal', ++ enable_relative_embedding=False, ++ scaled_query_init=False, ++ scale_attn_logits=True, ++ transpose_batch_sequence=False ++ ) ++ ++ te_transformer_tpl.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) ++ te_transformer_tpl.params_init = stacked_transformer_obj.params_init ++ te_transformer_tpl.dtype = stacked_transformer_obj.fprop_dtype ++ te_transformer_tpl.layer_type = te_praxis.TransformerLayerType.DECODER if stacked_transformer_obj.use_cross_attention \ ++ else te_praxis.TransformerLayerType.ENCODER ++ te_transformer_tpl.num_attention_heads = stacked_transformer_obj.num_heads ++ te_transformer_tpl.hidden_size = stacked_transformer_obj.model_dims ++ te_transformer_tpl.mlp_hidden_size = stacked_transformer_obj.hidden_dims ++ te_transformer_tpl.layernorm_epsilon = stacked_transformer_obj.transformer_layer_params_tpl.ln_tpl.epsilon ++ te_transformer_tpl.dropout_rng_name = base_layer.RANDOM ++ te_transformer_tpl.attention_dropout = stacked_transformer_obj.atten_dropout_prob or stacked_transformer_obj.dropout_prob ++ te_transformer_tpl.hidden_dropout = stacked_transformer_obj.residual_dropout_prob or stacked_transformer_obj.dropout_prob ++ te_transformer_tpl.intermediate_dropout = stacked_transformer_obj.relu_dropout_prob or stacked_transformer_obj.dropout_prob ++ if stacked_transformer_obj.residual_droppath_prob > 0.0: ++ te_transformer_tpl.drop_path = ( ++ stacked_transformer_obj.residual_droppath_prob * layer_id / max(1, stacked_transformer_obj.num_layers) ++ ) ++ ++ assert stacked_transformer_obj.dim_per_head == stacked_transformer_obj.model_dims // stacked_transformer_obj.num_heads ++ assert stacked_transformer_obj.packed_input == False ++ assert len(stacked_transformer_obj.moe_layers) == 0 ++ assert stacked_transformer_obj.ngrammer_tpls is None ++ ++ return te_transformer_tpl ++ ++ @staticmethod ++ def get_bld_mapping_for_pipelined_transformer(_): ++ rules = te_flax.extend_logical_axis_rules(tuple()) ++ batch_mapping = rules[0] ++ hidden_tp_mapping = rules[4] ++ # [Batch, Seqlen, Hidden] ++ bld_mapping = [batch_mapping, None, hidden_tp_mapping] ++ return bld_mapping ++ ++ ++ ++ ++class TransformerEngineHelper(TransformerEngineHelperBase): ++ ++ @staticmethod ++ def is_enabled_te(): ++ enable_te = bool(int((os.environ.get("ENABLE_TE", False)))) ++ return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te) ++ ++ @staticmethod ++ def get_helper(): ++ if TransformerEngineHelper.is_enabled_te(): ++ return TEInstalledHelper ++ return TENotInstalledHelper ++ ++ @staticmethod ++ def get_fprop_caller_of_stack_transformer(fprop, deterministic): ++ return TransformerEngineHelper.get_helper().get_fprop_caller_of_stack_transformer( ++ fprop, deterministic) ++ ++ @staticmethod ++ def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id): ++ return TransformerEngineHelper.get_helper().set_layer_params_to_stack_transformer( ++ stacked_transformer_obj, layer_p, layer_id) ++ ++ @staticmethod ++ def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): ++ return TransformerEngineHelper.get_helper().get_bld_mapping_for_pipelined_transformer( ++ xformer_layer_p) ++ ++ +diff --git a/praxis/layers/transformers.py b/praxis/layers/transformers.py +index ab6cff3..c79dac9 100644 +--- a/praxis/layers/transformers.py ++++ b/praxis/layers/transformers.py +@@ -39,6 +39,7 @@ from praxis.layers import pipeline + from praxis.layers import repeats + from praxis.layers import stats + from praxis.layers import stochastics ++from praxis.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper + + NestedMap = py_utils.NestedMap + WeightInit = base_layer.WeightInit +@@ -1655,23 +1656,8 @@ class StackedTransformer(base_layer.BaseLayer): + p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii]) + else: + p_i = self._clone_layer_params(self.transformer_layer_params_tpl) +- p_i.name = f'layer_{i}' +- p_i.use_cross_attention = self.use_cross_attention +- p_i.num_heads = self.num_heads +- p_i.dim_per_head = self.dim_per_head +- p_i.input_dims = self.model_dims +- p_i.packed_input = self.packed_input +- p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob +- p_i.residual_dropout_prob = ( +- self.residual_dropout_prob or self.dropout_prob +- ) +- p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob +- p_i.hidden_dims = self.hidden_dims + +- if self.residual_droppath_prob > 0.0: +- p_i.residual_droppath_prob = ( +- self.residual_droppath_prob * i / max(1, self.num_layers) +- ) ++ p_i = TransformerEngineHelper.set_layer_params_to_stack_transformer(self, p_i, i) + + if self.moe_layers and i in self.moe_layers: + assert self.num_experts > 0 +@@ -1790,6 +1776,8 @@ class StackedTransformer(base_layer.BaseLayer): + ) + return x_out + ++ _fprop = TransformerEngineHelper.get_fprop_caller_of_stack_transformer(_fprop, self.do_eval) ++ + fprop = _fprop + if self.remat: + fprop = nn.remat( +@@ -2255,7 +2243,7 @@ class PipelinedTransformer(base_layer.BaseLayer): + else: + assert self.pipeline_stage.cls == StackedTransformerRepeated + xformer_layer_p = self.pipeline_stage.block.transformer_layer_params_tpl +- bld_mapping = xformer_layer_p.tr_atten_tpl.activation_split_dims_mapping.bld ++ bld_mapping = TransformerEngineHelper.get_bld_mapping_for_pipelined_transformer(xformer_layer_p) + if not self.stream_io: + # Annotate the inputs before the pipeline to prevent unexpected + # propagation from earlier layers. +-- +2.25.1 + + +From ff1745796009cf1ec59f463f8e776c66f1286938 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Fri, 17 Nov 2023 15:21:06 +0800 +Subject: [PATCH 2/2] Fix missing vars wiht PP. + +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 34 ++++++++++++--------- + praxis/layers/pipeline.py | 7 +++-- + 2 files changed, 25 insertions(+), 16 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index 2d5277e..050d441 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -5,18 +5,25 @@ from praxis import pax_fiddle + from praxis import pytypes + + try: +- import transformer_engine.jax as te +- import transformer_engine.jax.flax as te_flax +- import transformer_engine.jax.praxis as te_praxis +- _IS_TRANSFORMER_ENGINE_INSTALLED = True +- import praxis.layers.repeats as praxis_repeat +- # This is to make Repeat module correctly generate collections we need. +- praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes +- te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) ++ import transformer_engine.jax as te ++ import transformer_engine.jax.flax as te_flax ++ import transformer_engine.jax.praxis as te_praxis ++ _IS_TRANSFORMER_ENGINE_INSTALLED = True ++ import praxis.layers.repeats as praxis_repeat ++ # This is to make Repeat module correctly generate collections we need. ++ praxis_repeat.SCAN_VARIABLE_AXES.update({base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes ++ te.fp8.FP8Helper.FP8_COLLECTION_NAME:0}) ++ TE_PIPELINE_EXTRA_VMAP_VAR_AXES = { ++ base_layer.NON_PAX_VAR_COLLECTION[1]: 0, # 1-idx = params_axes ++ te.fp8.FP8Helper.FP8_COLLECTION_NAME:0 ++ } ++ ++ TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST = [te.fp8.FP8Helper.FP8_COLLECTION_NAME] + + except ModuleNotFoundError as e: +- _IS_TRANSFORMER_ENGINE_INSTALLED = False +- ++ _IS_TRANSFORMER_ENGINE_INSTALLED = False ++ TE_PIPELINE_EXTRA_VMAP_VAR_AXES = {} ++ TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST = [] + + LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] + JTensor = pytypes.JTensor +@@ -136,8 +143,9 @@ class TEInstalledHelper(TransformerEngineHelperBase): + @staticmethod + def get_bld_mapping_for_pipelined_transformer(_): + rules = te_flax.extend_logical_axis_rules(tuple()) +- batch_mapping = rules[0] +- hidden_tp_mapping = rules[4] ++ # rules [(batch_axis_name, ('replicat', 'data'))', ...)] ++ batch_mapping = rules[0][1] ++ hidden_tp_mapping = rules[4][1] + # [Batch, Seqlen, Hidden] + bld_mapping = [batch_mapping, None, hidden_tp_mapping] + return bld_mapping +@@ -172,5 +180,3 @@ class TransformerEngineHelper(TransformerEngineHelperBase): + def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): + return TransformerEngineHelper.get_helper().get_bld_mapping_for_pipelined_transformer( + xformer_layer_p) +- +- +diff --git a/praxis/layers/pipeline.py b/praxis/layers/pipeline.py +index e3b2f7c..b31526e 100644 +--- a/praxis/layers/pipeline.py ++++ b/praxis/layers/pipeline.py +@@ -28,6 +28,8 @@ from praxis import pax_fiddle + from praxis import py_utils + from praxis import pytypes + from praxis.layers import checkpoint_policy ++from praxis.contrib.gpu.scripts_gpu.te_helper import TE_PIPELINE_EXTRA_VMAP_VAR_AXES ++from praxis.contrib.gpu.scripts_gpu.te_helper import TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST + + NestedMap = py_utils.NestedMap + JTensor = pytypes.JTensor +@@ -414,6 +416,7 @@ class LayerwiseShardablePipelined(base_layer.BaseLayer): + NON_TRAINABLE: 0, + INTERMEDIATES: 0, + HYPER_PARAMS: 0, ++ **TE_PIPELINE_EXTRA_VMAP_VAR_AXES + }, + split_rngs={PARAMS: self.is_initializing(), RANDOM: True}, + metadata_params={ +@@ -798,7 +801,7 @@ class LayerwiseShardablePipelined(base_layer.BaseLayer): + # + # Note that fprop should not use PARAMS rng because there is no var init. + variable_carry = [] +- variable_broadcast = [PARAMS] ++ variable_broadcast = [PARAMS] + TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST + if self.is_mutable_collection(NON_TRAINABLE): + variable_carry.append(NON_TRAINABLE) + else: +@@ -821,7 +824,7 @@ class LayerwiseShardablePipelined(base_layer.BaseLayer): + if bf16_vars_to_convert is not None: + scan_fn = nn.map_variables( + scan_fn, +- mapped_collections=[PARAMS], ++ mapped_collections=[PARAMS, 'fp8_meta_collection'], + mutable=True, + trans_in_fn=_get_to_f32_converter(bf16_vars_to_convert), + trans_out_fn=_get_to_bf16_converter(bf16_vars_to_convert), +-- +2.25.1 + From 5688cac3935363d9ac69e34583cc0c798b92850c Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 1 Dec 2023 17:22:42 -0800 Subject: [PATCH 08/29] apply xla fix reversion from main --- .github/container/Dockerfile.jax | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index 9442bf9c5..85b14ab04 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -36,17 +36,6 @@ RUN --mount=type=ssh \ --mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \ get-source.sh -l xla -m ${MANIFEST_FILE} -# TODO: This is a WAR to NCCL errors we observe in TOT. Should be removed when no longer needed -RUN < Date: Fri, 1 Dec 2023 17:58:26 -0800 Subject: [PATCH 09/29] style --- .github/container/bump.sh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/container/bump.sh b/.github/container/bump.sh index 439f7fbb8..c325e6bc9 100755 --- a/.github/container/bump.sh +++ b/.github/container/bump.sh @@ -55,16 +55,14 @@ set -eou pipefail SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -MANIFEST_IN=${MANIFEST_IN:-} -MANIFEST_OUT=${MANIFEST_OUT:-} ONLY_BUMP_PATCHES=${ONLY_BUMP_PATCHES:-0} -if [[ -z "$MANIFEST_IN" ]]; then +if [[ -z "${MANIFEST_IN:-}" ]]; then echo "Need to provide a value for -i/--input-manifest" usage 1 fi -if [[ -z "$MANIFEST_OUT" ]]; then +if [[ -z "${MANIFEST_OUT:-}" ]]; then # Perform the update in place MANIFEST_OUT=$MANIFEST_IN else From 27036f7c30e098a901845ce1295be1e9ca167aba Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 1 Dec 2023 18:00:47 -0800 Subject: [PATCH 10/29] short -b switch --- .github/container/bump.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/container/bump.sh b/.github/container/bump.sh index c325e6bc9..8f5c85a85 100755 --- a/.github/container/bump.sh +++ b/.github/container/bump.sh @@ -5,16 +5,16 @@ usage() { cat < Date: Fri, 1 Dec 2023 18:01:31 -0800 Subject: [PATCH 11/29] remove [optional] --- .github/container/bump.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/container/bump.sh b/.github/container/bump.sh index 8f5c85a85..aaba8a0c2 100755 --- a/.github/container/bump.sh +++ b/.github/container/bump.sh @@ -5,10 +5,10 @@ usage() { cat < Date: Fri, 1 Dec 2023 18:06:20 -0800 Subject: [PATCH 12/29] submodule init --- .github/container/get-source.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/container/get-source.sh b/.github/container/get-source.sh index afa96d9b7..d7ba18d04 100755 --- a/.github/container/get-source.sh +++ b/.github/container/get-source.sh @@ -91,8 +91,7 @@ set -ex -o pipefail git clone ${GIT_REPO} ${INSTALL_DIR} pushd ${INSTALL_DIR} git checkout ${GIT_REF} -git submodule init -git submodule update --recursive +git submodule update --init --recursive popd echo "Writing to ${OUT_REQUIREMENTS_FILE}:" From c5f0b8706a7c32c86b696c3234d8143f890d5a90 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 3 Dec 2023 21:33:01 -0800 Subject: [PATCH 13/29] .ref -> .latest_verified_commit --- .github/container/bump.sh | 2 +- .github/container/create-distribution.sh | 2 +- .github/container/get-source.sh | 2 +- .github/container/pip-finalize.sh | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/container/bump.sh b/.github/container/bump.sh index aaba8a0c2..bc57439a6 100755 --- a/.github/container/bump.sh +++ b/.github/container/bump.sh @@ -76,7 +76,7 @@ for pkg in $(yq e 'keys | .[]' $MANIFEST_OUT); do url=$(yq e ".${pkg}.url" $MANIFEST_OUT) tracking_ref=$(yq e ".${pkg}.tracking_ref" $MANIFEST_OUT) new_ref=$(git ls-remote $url $tracking_ref | awk '{print $1}') - yq e ".${pkg}.ref = \"$new_ref\"" -i $MANIFEST_OUT + yq e ".${pkg}.latest_verified_commit = \"$new_ref\"" -i $MANIFEST_OUT fi has_patches=$(yq e ".${pkg} | has(\"patches\")" $MANIFEST_OUT) diff --git a/.github/container/create-distribution.sh b/.github/container/create-distribution.sh index 8345fb337..ad5fc0e8b 100755 --- a/.github/container/create-distribution.sh +++ b/.github/container/create-distribution.sh @@ -146,7 +146,7 @@ done git fetch origin $TRACKING_REF # previous-HEAD's purpose is to point to the state of the repo before any distribution changes are made -# We do not rely on the manifest.yaml's .${library}.ref because local commits may be made on top by the upstream docker builds +# We do not rely on the manifest.yaml's .${library}.latest_verified_commit because local commits may be made on top by the upstream docker builds if ! git rev-parse --verify previous-HEAD >/dev/null 2>&1; then echo "[INFO]: Basing distribution on HEAD ($(git rev-parse HEAD)) and marking that with the local branch: previous-HEAD" git branch --force previous-HEAD HEAD diff --git a/.github/container/get-source.sh b/.github/container/get-source.sh index d7ba18d04..800d303c7 100755 --- a/.github/container/get-source.sh +++ b/.github/container/get-source.sh @@ -81,7 +81,7 @@ if [[ "${PACKAGE_MODE}" != "git-clone" ]]; then fi GIT_REPO=$(yq e ".${LIBRARY}.url" $MANIFEST) -GIT_REF=$(yq e ".${LIBRARY}.ref" $MANIFEST) +GIT_REF=$(yq e ".${LIBRARY}.latest_verified_commit" $MANIFEST) INSTALL_DIR=${BASE_INSTALL_DIR}/$LIBRARY echo "Fetching $GIT_REPO#$GIT_REF to $INSTALL_DIR" diff --git a/.github/container/pip-finalize.sh b/.github/container/pip-finalize.sh index 34244c164..3137a480d 100755 --- a/.github/container/pip-finalize.sh +++ b/.github/container/pip-finalize.sh @@ -14,7 +14,7 @@ IFS=$'\n' for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do # VCS installs are of the form "PACKAGE @ git+..." PACKAGE=$(echo "$line" | awk '{print $1}') - ref=$(yq e ".${PACKAGE}.ref" ${MANIFEST_FILE}) + ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE}) echo "${line}@${ref}" done | tee requirements.vcs unset IFS From 645d4af575359615036c8d3804955e34e079da18 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 3 Dec 2023 21:36:11 -0800 Subject: [PATCH 14/29] .ref -> .latest_verified_commit (pt2) --- .github/container/manifest.yaml | 32 ++++++++++----------- rosetta/tests/extra-only-distribution.sh | 2 +- rosetta/tests/mirror-only-distribution.sh | 2 +- rosetta/tests/upstream-only-distribution.sh | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 20d1071f9..a19346d41 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -2,31 +2,31 @@ jax: url: https://github.com/google/jax.git tracking_ref: main - ref: 595117b70c11055e569480b80907d8c8a9901805 + latest_verified_commit: 595117b70c11055e569480b80907d8c8a9901805 mode: git-clone xla: url: https://github.com/openxla/xla.git tracking_ref: main - ref: 78a5297d8e4301cb3ba2514061f56f89104e3d88 + latest_verified_commit: 78a5297d8e4301cb3ba2514061f56f89104e3d88 mode: git-clone flax: url: https://github.com/google/flax.git mirror_url: https://github.com/nvjax-svc-0/flax.git tracking_ref: main - ref: 230b0d77e98da22b6e574c3cbff743ca1504bfca + latest_verified_commit: 230b0d77e98da22b6e574c3cbff743ca1504bfca mode: git-clone patches: pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules transformer-engine: url: https://github.com/NVIDIA/TransformerEngine.git tracking_ref: main - ref: 92c1e500dd14608e54f75df8276baa1104c61d48 + latest_verified_commit: 92c1e500dd14608e54f75df8276baa1104c61d48 mode: git-clone t5x: url: https://github.com/google-research/t5x.git mirror_url: https://github.com/nvjax-svc-0/t5x.git tracking_ref: main - ref: 1bfd2f15e5e77b09d60301367f67fdc9bb756b46 + latest_verified_commit: 1bfd2f15e5e77b09d60301367f67fdc9bb756b46 mode: git-clone patches: mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore @@ -36,7 +36,7 @@ paxml: url: https://github.com/google/paxml.git mirror_url: https://github.com/nvjax-svc-0/paxml.git tracking_ref: main - ref: 7ae682d4d99630008e190b96c5296990297175c2 + latest_verified_commit: 7ae682d4d99630008e190b96c5296990297175c2 mode: git-clone patches: pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support @@ -44,7 +44,7 @@ praxis: url: https://github.com/google/praxis.git mirror_url: https://github.com/nvjax-svc-0/praxis.git tracking_ref: main - ref: b6f32fa0fc6721db1cec75972b0f569c82095956 + latest_verified_commit: b6f32fa0fc6721db1cec75972b0f569c82095956 mode: git-clone patches: pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. @@ -53,13 +53,13 @@ lingvo: # Used only in ARM pax builds url: https://github.com/tensorflow/lingvo.git tracking_ref: master - ref: 0274fa20b4ff194c1c118b94b5f778caa5d9a84a + latest_verified_commit: 0274fa20b4ff194c1c118b94b5f778caa5d9a84a mode: git-clone tensorflow-text: # Used only in ARM pax builds url: https://github.com/tensorflow/text.git tracking_ref: v2.13.0 - ref: 917a681d7220ebf9b62a08b6f9ce7b7db886ddef + latest_verified_commit: 917a681d7220ebf9b62a08b6f9ce7b7db886ddef mode: git-clone pydantic: version: X.Y.Z @@ -68,36 +68,36 @@ pydantic: fiddle: url: https://github.com/google/fiddle.git tracking_ref: main - ref: d409cf95164599a88e49d2b6a23a0972a7170b0b + latest_verified_commit: d409cf95164599a88e49d2b6a23a0972a7170b0b mode: pip-vcs # Used by t5x airio: url: https://github.com/google/airio.git tracking_ref: main - ref: 69b3ec4ded478ad9cacdc97652a9d086a6a644c4 + latest_verified_commit: 69b3ec4ded478ad9cacdc97652a9d086a6a644c4 mode: pip-vcs clu: url: https://github.com/google/CommonLoopUtils.git tracking_ref: main - ref: 7ba2a9d83a3bc1a97b59482c2f02dc4b3614bc31 + latest_verified_commit: 7ba2a9d83a3bc1a97b59482c2f02dc4b3614bc31 mode: pip-vcs dllogger: url: https://github.com/NVIDIA/dllogger.git tracking_ref: master - ref: 0540a43971f4a8a16693a9de9de73c1072020769 + latest_verified_commit: 0540a43971f4a8a16693a9de9de73c1072020769 mode: pip-vcs jestimator: url: https://github.com/google-research/jestimator.git tracking_ref: main - ref: "fa143d93e337ca8ab77c4510baf21ae52af24ab2" + latest_verified_commit: "fa143d93e337ca8ab77c4510baf21ae52af24ab2" mode: pip-vcs optax: url: https://github.com/deepmind/optax.git tracking_ref: master - ref: bf987e15eacf6efeb1a1a51b8868c094c3a15f9b + latest_verified_commit: bf987e15eacf6efeb1a1a51b8868c094c3a15f9b mode: pip-vcs seqio: url: https://github.com/google/seqio.git tracking_ref: main - ref: 515d917bf58da4103a2bbf39c3716213c36aff03 + latest_verified_commit: 515d917bf58da4103a2bbf39c3716213c36aff03 mode: pip-vcs diff --git a/rosetta/tests/extra-only-distribution.sh b/rosetta/tests/extra-only-distribution.sh index 73257042d..5746ca50c 100755 --- a/rosetta/tests/extra-only-distribution.sh +++ b/rosetta/tests/extra-only-distribution.sh @@ -34,7 +34,7 @@ t5x: mirror_url: https://github.com/nvjax-svc-0/t5x.git extra_dir: $extra_tmp tracking_ref: main - ref: $DISTRIBUTION_BASE_REF + latest_verified_commit: $DISTRIBUTION_BASE_REF mode: git-clone patches: $EXTRA_PATCH_BRANCH: null diff --git a/rosetta/tests/mirror-only-distribution.sh b/rosetta/tests/mirror-only-distribution.sh index 1726dab61..2a62cd66d 100755 --- a/rosetta/tests/mirror-only-distribution.sh +++ b/rosetta/tests/mirror-only-distribution.sh @@ -20,7 +20,7 @@ t5x: url: https://github.com/google-research/t5x.git mirror_url: https://github.com/nvjax-svc-0/t5x.git tracking_ref: main - ref: $DISTRIBUTION_BASE_REF + latest_verified_commit: $DISTRIBUTION_BASE_REF mode: git-clone patches: mirror/pull/4/head: null diff --git a/rosetta/tests/upstream-only-distribution.sh b/rosetta/tests/upstream-only-distribution.sh index a314534b6..9993cd78f 100755 --- a/rosetta/tests/upstream-only-distribution.sh +++ b/rosetta/tests/upstream-only-distribution.sh @@ -20,7 +20,7 @@ t5x: url: https://github.com/google-research/t5x.git mirror_url: https://github.com/nvjax-svc-0/t5x.git tracking_ref: main - ref: $DISTRIBUTION_BASE_REF + latest_verified_commit: $DISTRIBUTION_BASE_REF mode: git-clone patches: pull/1372/head: null From e780f9b20e51cd59a2fc6c7f3396b1a289eea3c9 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 3 Dec 2023 22:31:48 -0800 Subject: [PATCH 15/29] better description in create-distribution.sh --- .github/container/create-distribution.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/container/create-distribution.sh b/.github/container/create-distribution.sh index ad5fc0e8b..6109b2e1c 100755 --- a/.github/container/create-distribution.sh +++ b/.github/container/create-distribution.sh @@ -4,6 +4,11 @@ usage() { cat < Date: Mon, 4 Dec 2023 15:24:31 -0800 Subject: [PATCH 16/29] typo --- .github/container/create-distribution.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/container/create-distribution.sh b/.github/container/create-distribution.sh index 6109b2e1c..504528b22 100755 --- a/.github/container/create-distribution.sh +++ b/.github/container/create-distribution.sh @@ -4,7 +4,7 @@ usage() { cat < Date: Mon, 4 Dec 2023 22:34:38 -0800 Subject: [PATCH 17/29] remove quotes --- .github/container/manifest.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index a19346d41..7620bb177 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -89,7 +89,7 @@ dllogger: jestimator: url: https://github.com/google-research/jestimator.git tracking_ref: main - latest_verified_commit: "fa143d93e337ca8ab77c4510baf21ae52af24ab2" + latest_verified_commit: fa143d93e337ca8ab77c4510baf21ae52af24ab2 mode: pip-vcs optax: url: https://github.com/deepmind/optax.git From 48a3a096cb62eafba11b625e6a1e94feca70395a Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 4 Dec 2023 22:41:26 -0800 Subject: [PATCH 18/29] EOF & INNEREOF --- .github/container/Dockerfile.pax.arm64 | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/container/Dockerfile.pax.arm64 b/.github/container/Dockerfile.pax.arm64 index 5a45c267b..feea160ed 100644 --- a/.github/container/Dockerfile.pax.arm64 +++ b/.github/container/Dockerfile.pax.arm64 @@ -23,12 +23,12 @@ RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazeli FROM wheel-builder as tftext-builder ARG SRC_PATH_TFTEXT -RUN <<"EOT" bash -exu -o pipefail +RUN <<"EOF" bash -exu -o pipefail pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.13.0 get-source.sh -l tensorflow-text -m ${MANIFEST_FILE} cd ${SRC_PATH_TFTEXT} ./oss_scripts/run_build.sh -EOT +EOF #------------------------------------------------------------------------------ # build lingvo @@ -43,14 +43,14 @@ COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ RUN get-source.sh -l lingvo -m ${MANIFEST_FILE} # build lingvo -RUN <<"EOT" bash -exu -o pipefail +RUN <<"EOF" bash -exu -o pipefail pushd ${SRC_PATH_LINGVO} git fetch origin pull/329/head:pr329 git cherry-pick --allow-empty pr329 # Disable 2 flaky tests here -patch -p1 <<"EOF" +patch -p1 <<"EOFINNER" diff --git a/pip_package/build.sh b/pip_package/build.sh index ef62c432e..659e78956 100755 --- a/pip_package/build.sh @@ -64,7 +64,7 @@ index ef62c432e..659e78956 100755 fi DST_DIR="/tmp/lingvo/dist" -EOF +EOFINNER pip install tensorflow_datasets==4.9.2 auditwheel tensorflow==2.13.0 /opt/tensorflow_text*.whl sed -i 's/tensorflow=/#tensorflow=/' docker/dev.requirements.txt @@ -76,7 +76,7 @@ pip install -r docker/dev.requirements.txt # running the tests entirely by uncommentin the following line. # SKIP_TEST=1 PYTHON_MINOR_VERSION=$(python --version | cut -d ' ' -f 2 | cut -d '.' -f 2) pip_package/build.sh -EOT +EOF ############################################################################### ## Pax for AArch64 @@ -95,7 +95,7 @@ COPY --from=tftext-builder ${SRC_PATH_TFTEXT}/tensorflow_text*.whl /opt/ RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-tools.d/requirements-paxml.in # paxml + praxis -RUN <<"EOT" bash -ex +RUN <<"EOF" bash -ex echo "tensorflow==2.13.0" >> /opt/pip-tools.d/requirements-paxml.in echo "tensorflow_datasets==4.9.2" >> /opt/pip-tools.d/requirements-paxml.in echo "chex==0.1.7" >> /opt/pip-tools.d/requirements-paxml.in @@ -130,7 +130,7 @@ for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do fi popd done -EOT +EOF ADD test-pax.sh /usr/local/bin From e8b857b66edc36c7df62c365b2b9b1006394918f Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 4 Dec 2023 23:00:26 -0800 Subject: [PATCH 19/29] bump.sh documentation --- .github/container/bump.sh | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/container/bump.sh b/.github/container/bump.sh index bc57439a6..359778f7d 100755 --- a/.github/container/bump.sh +++ b/.github/container/bump.sh @@ -4,17 +4,24 @@ usage() { cat < /dev/null && pwd ) -ONLY_BUMP_PATCHES=${ONLY_BUMP_PATCHES:-0} +SKIP_BUMP_REFS=${SKIP_BUMP_REFS:-0} if [[ -z "${MANIFEST_IN:-}" ]]; then echo "Need to provide a value for -i/--input-manifest" @@ -72,7 +79,7 @@ fi for pkg in $(yq e 'keys | .[]' $MANIFEST_OUT); do mode=$(yq e ".${pkg}.mode" $MANIFEST_OUT) - if [[ $mode == git-clone || $mode == pip-vcs ]] && [[ $ONLY_BUMP_PATCHES -eq 0 ]]; then + if [[ $mode == git-clone || $mode == pip-vcs ]] && [[ $SKIP_BUMP_REFS -eq 0 ]]; then url=$(yq e ".${pkg}.url" $MANIFEST_OUT) tracking_ref=$(yq e ".${pkg}.tracking_ref" $MANIFEST_OUT) new_ref=$(git ls-remote $url $tracking_ref | awk '{print $1}') From 68b001a827cde893ae340c8e53d68bd5a096d98f Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 4 Dec 2023 23:02:33 -0800 Subject: [PATCH 20/29] rm [Optional] --- .github/container/create-distribution.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/container/create-distribution.sh b/.github/container/create-distribution.sh index 504528b22..0d424a927 100755 --- a/.github/container/create-distribution.sh +++ b/.github/container/create-distribution.sh @@ -10,12 +10,12 @@ or fixes from the patches. This script does not build or install the library, bu that includes all of the patches. Usage: $0 [OPTION]... - -c, --clean [Optional] If set, will clean the patch dir. Default is not to clean - -h, --help [Optional] Print usage. + -c, --clean If set, will clean the patch dir. Default is not to clean + -h, --help Print usage. -m, --manifest=PATH Path to the manifest. Updates it in-place - -o, --override_dir=PATH [Optional] Use this if there is a custom location of the upstream clone. If not specified, uses /opt/\${PACKAGE} + -o, --override_dir=PATH Use this if there is a custom location of the upstream clone. If not specified, uses /opt/\${PACKAGE} -p, --package=KEY The package name in the manifest to use, e.g., t5x, paxml - -s, --skip-apply [Optional] If provided, will only create patches, update manifest, and skip applying. When not provided, applies local patches. + -s, --skip-apply If provided, will only create patches, update manifest, and skip applying. When not provided, applies local patches. -------------- From 7281f59844762eb9708b5fca7be87c950fa392f8 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 4 Dec 2023 23:07:00 -0800 Subject: [PATCH 21/29] get-source.sh documentation --- .github/container/get-source.sh | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/.github/container/get-source.sh b/.github/container/get-source.sh index 800d303c7..86ebcb7ef 100755 --- a/.github/container/get-source.sh +++ b/.github/container/get-source.sh @@ -1,20 +1,24 @@ #!/bin/bash -## Clone a git repo and write the pip-compile input to stdout -## Example: -## get-source.sh -m manifest.yaml -l flax -## Output: -## -e /opt/flax ## Parse command-line arguments usage() { - echo "Usage: $0 [OPTION]..." - echo " -b, --base-dir DIR Directory to install package under. Default /opt" - echo " -h, --help Print usage." - echo " -l, --library LIB The library to clone, e.g., jax, flax, t5x" - echo " -m, --manifest FILE The JAX-Toolbox manifest yaml file" - echo " -o, --out-requirements Create a pip manifest file if specified" - echo +cat < Date: Mon, 4 Dec 2023 23:11:08 -0800 Subject: [PATCH 22/29] trial branch description --- .github/workflows/_ci.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index fb8919221..327910fec 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -9,7 +9,9 @@ on: required: true TRIAL_BRANCH: type: string + description: 'Name of (branch|tag|sha) with bumped manifest and patches' required: true + default: 'main' outputs: TAG_BASE: description: "Tags of the base image built" From 2e9fccde0fa39c490da3c1b62bf31e92b95d120b Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 4 Dec 2023 23:11:54 -0800 Subject: [PATCH 23/29] revert sandbox --- .github/workflows/_sandbox.yaml | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/.github/workflows/_sandbox.yaml b/.github/workflows/_sandbox.yaml index 8d1070c9c..37fa6ca68 100644 --- a/.github/workflows/_sandbox.yaml +++ b/.github/workflows/_sandbox.yaml @@ -2,27 +2,10 @@ name: "~Sandbox" on: workflow_dispatch: - workflow_run: - workflows: - - Nightly Dsitribution test - - Nightly JAX unit test - - Nightly Transformer Engine test - - Nightly Pax MGMN performance test # The trial branch is propagated thru this workflow - - Nightly T5X MGMN performance test - - Nightly Rosetta Paxml build and test - - Nightly Rosetta T5x build and test - types: [completed] - branches: [main] - -permissions: - contents: write # to fetch code, and create commits - actions: write # to cancel previous workflows - packages: write # to upload container jobs: sandbox: runs-on: ubuntu-22.04 - if: always() steps: - name: Login to GitHub Container Registry uses: docker/login-action@v2 From a554f01bb3b1d81ee4d82e0741f4fd21cab12552 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 7 Dec 2023 11:00:12 -0800 Subject: [PATCH 24/29] small comment + set -u flag in pip-finalize --- .github/container/pip-finalize.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/container/pip-finalize.sh b/.github/container/pip-finalize.sh index 3137a480d..371764f6a 100755 --- a/.github/container/pip-finalize.sh +++ b/.github/container/pip-finalize.sh @@ -1,6 +1,6 @@ #!/bin/bash -set -exo pipefail +set -eoux pipefail pushd /opt/pip-tools.d @@ -30,6 +30,7 @@ unset IFS # that treats the above as equivalent and prefers the URI wit the SHA JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in) +# If there are unpinned VCS dependencies, error since these should be included in the manifest unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true) if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then echo "Unpinned VCS installs found in $(readlink -f requirements.txt):" From 723f91de56c3f7410fd77affd0d7c4196d91d066 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 8 Dec 2023 15:35:40 -0800 Subject: [PATCH 25/29] fix merge conflict but also allow pass thru of base_build to jax build --- .github/workflows/nightly-jax-build.yaml | 7 +++++++ .github/workflows/nightly-rosetta-pax-build.yaml | 4 ++-- .github/workflows/nightly-rosetta-t5x-build-test.yaml | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/nightly-jax-build.yaml b/.github/workflows/nightly-jax-build.yaml index 5f13884f3..63e2ad52f 100644 --- a/.github/workflows/nightly-jax-build.yaml +++ b/.github/workflows/nightly-jax-build.yaml @@ -6,6 +6,11 @@ on: - cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC workflow_dispatch: inputs: + BASE_IMAGE: + type: string + description: 'CUDA base image built by NVIDIA/JAX-Toolbox' + default: 'ghcr.io/nvidia/jax-toolbox:base' + required: false PUBLISH: type: boolean description: Publish dated images and update the 'latest' tag? @@ -83,6 +88,7 @@ jobs: uses: ./.github/workflows/_build_jax.yaml with: ARCHITECTURE: amd64 + BASE_IMAGE: ${{ inputs.BASE_IMAGE }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} TRIAL_BRANCH: ${{ needs.bump-world-state.outputs.TRIAL_BRANCH }} secrets: inherit @@ -92,6 +98,7 @@ jobs: uses: ./.github/workflows/_build_jax.yaml with: ARCHITECTURE: arm64 + BASE_IMAGE: ${{ inputs.BASE_IMAGE }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} TRIAL_BRANCH: ${{ needs.bump-world-state.outputs.TRIAL_BRANCH }} secrets: inherit diff --git a/.github/workflows/nightly-rosetta-pax-build.yaml b/.github/workflows/nightly-rosetta-pax-build.yaml index 164f62438..7b6d3c46a 100644 --- a/.github/workflows/nightly-rosetta-pax-build.yaml +++ b/.github/workflows/nightly-rosetta-pax-build.yaml @@ -33,11 +33,11 @@ jobs: metadata: runs-on: ubuntu-22.04 outputs: - BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }} + BUILD_DATE: ${{ steps.meta-vars.outputs.BUILD_DATE }} BASE_LIBRARY: ${{ steps.base-metadata.outputs.BASE_LIBRARY }} BASE_IMAGE_AMD64: ${{ steps.base-metadata.outputs.BASE_IMAGE_AMD64 }} BASE_IMAGE_ARM64: ${{ steps.base-metadata.outputs.BASE_IMAGE_ARM64 }} - PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }} + PUBLISH: ${{ steps.base-metadata.outputs.PUBLISH }} steps: - name: Check if the triggering workflow failed id: if-upstream-failed diff --git a/.github/workflows/nightly-rosetta-t5x-build-test.yaml b/.github/workflows/nightly-rosetta-t5x-build-test.yaml index 70adfb1f0..732482c95 100644 --- a/.github/workflows/nightly-rosetta-t5x-build-test.yaml +++ b/.github/workflows/nightly-rosetta-t5x-build-test.yaml @@ -37,7 +37,7 @@ jobs: BASE_LIBRARY: ${{ steps.base-metadata.outputs.BASE_LIBRARY }} BASE_IMAGE_AMD64: ${{ steps.base-metadata.outputs.BASE_IMAGE_AMD64 }} BASE_IMAGE_ARM64: ${{ steps.base-metadata.outputs.BASE_IMAGE_ARM64 }} - PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }} + PUBLISH: ${{ steps.base-metadata.outputs.PUBLISH }} steps: - name: Check if the triggering workflow failed id: if-upstream-failed From c0b683a4954b2112c56fc57db5984de993631c26 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 10 Dec 2023 22:22:34 -0800 Subject: [PATCH 26/29] Fix BASE_IMAGE description and upstream t5x/pax builds now allow base_image from workflow_dispatch --- .github/workflows/nightly-pax-build.yaml | 26 +++++++++++++++++++ .../workflows/nightly-rosetta-pax-build.yaml | 2 +- .../nightly-rosetta-t5x-build-test.yaml | 2 +- .github/workflows/nightly-t5x-build.yaml | 25 ++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightly-pax-build.yaml b/.github/workflows/nightly-pax-build.yaml index ee8dc3412..399a8aad6 100644 --- a/.github/workflows/nightly-pax-build.yaml +++ b/.github/workflows/nightly-pax-build.yaml @@ -8,12 +8,20 @@ on: branches: [main] workflow_dispatch: inputs: + BASE_IMAGE: + type: string + description: 'Upstream Jax mealkit image without $arch-mealkit suffix, e.g., (ghcr.io/nvidia/jax-toolbox-internal:6857094059-upstream-jax). Leaving empty implies ghcr.io/nvidia/jax:mealkit' + default: '' + required: false PUBLISH: type: boolean description: Publish nightly images and update the 'latest' tag? default: false required: false +env: + DOCKER_REGISTRY: ghcr.io/nvidia + permissions: contents: read # to fetch code actions: write # to cancel previous workflows @@ -25,6 +33,8 @@ jobs: runs-on: ubuntu-22.04 outputs: PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }} + BASE_IMAGE_AMD64: ${{ steps.base-image.outputs.BASE_IMAGE_AMD64 }} + BASE_IMAGE_ARM64: ${{ steps.base-image.outputs.BASE_IMAGE_ARM64 }} BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }} steps: - name: Check if the triggering workflow failed @@ -58,11 +68,26 @@ jobs: BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d') echo "BUILD_DATE=${BUILD_DATE}" >> $GITHUB_OUTPUT + - name: Set base image + id: base-image + shell: bash -x -e {0} + run: | + if [[ -z "${{ inputs.BASE_IMAGE }}" ]]; then + BASE_IMAGE_AMD64=${{ env.DOCKER_REGISTRY }}/jax:mealkit + BASE_IMAGE_ARM64=${{ env.DOCKER_REGISTRY }}/jax:mealkit + else + BASE_IMAGE_AMD64=${{ inputs.BASE_IMAGE }}-amd64-mealkit + BASE_IMAGE_ARM64=${{ inputs.BASE_IMAGE }}-arm64-mealkit + fi + echo "BASE_IMAGE_AMD64=${BASE_IMAGE_AMD64}" >> $GITHUB_OUTPUT + echo "BASE_IMAGE_ARM64=${BASE_IMAGE_ARM64}" >> $GITHUB_OUTPUT + amd64: needs: metadata uses: ./.github/workflows/_build_pax.yaml with: ARCHITECTURE: amd64 + BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_AMD64 }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} secrets: inherit @@ -71,6 +96,7 @@ jobs: uses: ./.github/workflows/_build_pax.yaml with: ARCHITECTURE: arm64 + BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_ARM64 }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} secrets: inherit diff --git a/.github/workflows/nightly-rosetta-pax-build.yaml b/.github/workflows/nightly-rosetta-pax-build.yaml index 7b6d3c46a..32a7ad370 100644 --- a/.github/workflows/nightly-rosetta-pax-build.yaml +++ b/.github/workflows/nightly-rosetta-pax-build.yaml @@ -10,7 +10,7 @@ on: inputs: BASE_IMAGE: type: string - description: 'PAX image built by NVIDIA/JAX-Toolbox' + description: 'Upstream Pax mealkit image without $arch-mealkit suffix, e.g., (ghcr.io/nvidia/jax-toolbox-internal:6857094059-upstream-pax). Leaving empty implies ghcr.io/nvidia/upstream-pax:mealkit' default: '' required: false PUBLISH: diff --git a/.github/workflows/nightly-rosetta-t5x-build-test.yaml b/.github/workflows/nightly-rosetta-t5x-build-test.yaml index 732482c95..fc7a78ce1 100644 --- a/.github/workflows/nightly-rosetta-t5x-build-test.yaml +++ b/.github/workflows/nightly-rosetta-t5x-build-test.yaml @@ -10,7 +10,7 @@ on: inputs: BASE_IMAGE: type: string - description: 'T5x image built by NVIDIA/JAX-Toolbox' + description: 'Upstream T5x mealkit image without $arch-mealkit suffix, e.g., (ghcr.io/nvidia/jax-toolbox-internal:6857094059-upstream-t5x). Leaving empty implies ghcr.io/nvidia/upstream-t5x:mealkit' default: '' required: false PUBLISH: diff --git a/.github/workflows/nightly-t5x-build.yaml b/.github/workflows/nightly-t5x-build.yaml index 48f62bc43..1679df9e2 100644 --- a/.github/workflows/nightly-t5x-build.yaml +++ b/.github/workflows/nightly-t5x-build.yaml @@ -8,12 +8,20 @@ on: branches: [main] workflow_dispatch: inputs: + BASE_IMAGE: + type: string + description: 'Upstream Jax mealkit image without $arch-mealkit suffix, e.g., (ghcr.io/nvidia/jax-toolbox-internal:6857094059-upstream-jax). Leaving empty implies ghcr.io/nvidia/jax:mealkit' + default: '' + required: false PUBLISH: type: boolean description: Publish dated images and update the 'latest' tag? default: false required: false +env: + DOCKER_REGISTRY: ghcr.io/nvidia + permissions: contents: read # to fetch code actions: write # to cancel previous workflows @@ -25,6 +33,8 @@ jobs: runs-on: ubuntu-22.04 outputs: PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }} + BASE_IMAGE_AMD64: ${{ steps.base-image.outputs.BASE_IMAGE_AMD64 }} + BASE_IMAGE_ARM64: ${{ steps.base-image.outputs.BASE_IMAGE_ARM64 }} BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }} steps: - name: Check if the triggering workflow failed @@ -58,11 +68,26 @@ jobs: BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d') echo "BUILD_DATE=${BUILD_DATE}" >> $GITHUB_OUTPUT + - name: Set base image + id: base-image + shell: bash -x -e {0} + run: | + if [[ -z "${{ inputs.BASE_IMAGE }}" ]]; then + BASE_IMAGE_AMD64=${{ env.DOCKER_REGISTRY }}/jax:mealkit + BASE_IMAGE_ARM64=${{ env.DOCKER_REGISTRY }}/jax:mealkit + else + BASE_IMAGE_AMD64=${{ inputs.BASE_IMAGE }}-amd64-mealkit + BASE_IMAGE_ARM64=${{ inputs.BASE_IMAGE }}-arm64-mealkit + fi + echo "BASE_IMAGE_AMD64=${BASE_IMAGE_AMD64}" >> $GITHUB_OUTPUT + echo "BASE_IMAGE_ARM64=${BASE_IMAGE_ARM64}" >> $GITHUB_OUTPUT + amd64: needs: metadata uses: ./.github/workflows/_build_t5x.yaml with: ARCHITECTURE: amd64 + BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_AMD64 }} BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} secrets: inherit From 0828455c4d9478eba642ab24a8d17140b84c51d9 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 10 Dec 2023 23:22:10 -0800 Subject: [PATCH 27/29] cleanup --- rosetta/Dockerfile.pax | 17 ----------------- rosetta/Dockerfile.t5x | 17 ----------------- 2 files changed, 34 deletions(-) diff --git a/rosetta/Dockerfile.pax b/rosetta/Dockerfile.pax index de9e0d71a..9c65e87a1 100644 --- a/rosetta/Dockerfile.pax +++ b/rosetta/Dockerfile.pax @@ -22,23 +22,6 @@ ARG UPDATE_PATCHES ENV ENABLE_TE=1 RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu -############DELETE -git config --global user.email "${GIT_USER_EMAIL}" -git config --global user.name "${GIT_USER_NAME}" - -if [[ ! -d /opt/pip ]]; then - git clone https://github.com/pypa/pip.git /opt/pip -fi -git -C /opt/pip checkout 23.3.1 -cp /mnt/jax-toolbox/.github/container/pip-vcs-equivalency.patch /opt/pip/ -cd /opt/pip -git apply Date: Tue, 12 Dec 2023 23:18:58 -0800 Subject: [PATCH 28/29] _build_rosetta.yaml BADGE_FILENAME -> BADGE_FILENAME_PREFIX --- .github/workflows/_build_rosetta.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_build_rosetta.yaml b/.github/workflows/_build_rosetta.yaml index d2d8d4df1..7263fcf15 100644 --- a/.github/workflows/_build_rosetta.yaml +++ b/.github/workflows/_build_rosetta.yaml @@ -26,7 +26,7 @@ on: description: 'Name of the artifact zip file' required: false default: 'artifact-rosetta-build' - BADGE_FILENAME: + BADGE_FILENAME_PREFIX: type: string description: 'Name of the endpoint JSON file for shields.io badge (w/o .json || arch || library)' required: false @@ -53,7 +53,7 @@ jobs: build-rosetta: runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small] env: - BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }}.json + BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME_PREFIX }}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }}.json ARTIFACT_NAME_FULL: ${{ inputs.ARTIFACT_NAME}}-${{ inputs.BASE_LIBRARY }}-${{ inputs.ARCHITECTURE }} outputs: DOCKER_TAG_MEALKIT: ${{ steps.mealkit-metadata.outputs.tags }} From 811c9d2e75d214163c9f15a0a0a00b90049268b9 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 12 Dec 2023 23:20:39 -0800 Subject: [PATCH 29/29] Remove trial-branch default 'main' which is confusing --- .github/workflows/_build_jax.yaml | 1 - .github/workflows/_ci.yaml | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/_build_jax.yaml b/.github/workflows/_build_jax.yaml index 20ebdba8d..c85a5ea16 100644 --- a/.github/workflows/_build_jax.yaml +++ b/.github/workflows/_build_jax.yaml @@ -41,7 +41,6 @@ on: type: string description: 'Name of branch with bumped manifest and patches' required: true - default: 'main' outputs: DOCKER_TAG_MEALKIT: description: "Tags of the 'mealkit' image built" diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 327910fec..bb9badff7 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -11,7 +11,6 @@ on: type: string description: 'Name of (branch|tag|sha) with bumped manifest and patches' required: true - default: 'main' outputs: TAG_BASE: description: "Tags of the base image built"