Skip to content

Commit

Permalink
Option to install upstream PyTorch from nightly wheels (#2386)
Browse files Browse the repository at this point in the history
Adding a new workflow input to select how to install upstream PyTorch:
build from sources (with the corresponding patches applied) or from the
latest nightly wheels from https://download.pytorch.org/whl/nightly/xpu.
The default is to build from source.

Fixes #1913.
  • Loading branch information
pbchekin authored Oct 4, 2024
1 parent cac829d commit 11258f5
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
27 changes: 22 additions & 5 deletions .github/actions/setup-pytorch/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ inputs:
ref:
description: Branch, tag, commit id
default: ""
mode:
description: Source or wheels
default: source
runs:
using: "composite"
steps:
Expand Down Expand Up @@ -71,7 +74,7 @@ runs:
- name: Generate PyTorch cache key
shell: bash
run: |
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} | sha256sum - | cut -d\ -f1)
PYTORCH_CACHE_KEY=$(echo $PYTHON_VERSION $PYTORCH_COMMIT_ID ${{ hashFiles('scripts/patch-pytorch.sh') }} ${{ inputs.mode }} | sha256sum - | cut -d\ -f1)
echo "PYTORCH_CACHE_KEY=$PYTORCH_CACHE_KEY" | tee -a "$GITHUB_ENV"
- name: Load PyTorch from a cache
Expand All @@ -90,11 +93,12 @@ runs:
with:
repository: ${{ env.PYTORCH_REPO }}
ref: ${{ env.PYTORCH_COMMIT_ID }}
submodules: recursive
# To build PyTorch from source we need all submodules, they are not required for benchmarks
submodules: ${{ inputs.mode == 'source' && 'recursive' || 'false' }}
path: pytorch

- name: Apply additional PR patches
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' }}
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
shell: bash
run: |
cd pytorch
Expand All @@ -108,7 +112,7 @@ runs:
pip install 'numpy<2.0.0'
- name: Build PyTorch
if: ${{ steps.pytorch-cache.outputs.status == 'miss' }}
if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' }}
shell: bash
run: |
source ${{ inputs.oneapi }}/setvars.sh
Expand All @@ -117,11 +121,24 @@ runs:
pip install -r requirements.txt
python setup.py bdist_wheel
- name: Install PyTorch
- name: Install PyTorch (built from source)
if: ${{ inputs.mode == 'source' }}
shell: bash
run: |
source ${{ inputs.oneapi }}/setvars.sh
pip install pytorch/dist/*.whl
- name: Install PyTorch (from wheels)
if: ${{ inputs.mode == 'wheels' }}
shell: bash
run: |
source ${{ inputs.oneapi }}/setvars.sh
pip install torch --index-url https://download.pytorch.org/whl/nightly/xpu
- name: Get PyTorch version
shell: bash
run: |
source ${{ inputs.oneapi }}/setvars.sh
PYTORCH_VERSION="$(python -c 'import torch;print(torch.__version__)')"
echo "PYTORCH_VERSION=$PYTORCH_VERSION" | tee -a "$GITHUB_ENV"
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/build-test-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ on:
description: PyTorch ref, keep empty for default
type: string
default: ""
pytorch_mode:
description: PyTorch mode, source or wheels
type: choice
options:
- source
- wheels
default: source
upload_test_reports:
description: Upload test reports
type: boolean
Expand Down Expand Up @@ -46,6 +53,7 @@ jobs:
device: ${{ inputs.runner_label }}
runner_label: ${{ inputs.runner_label }}
pytorch_ref: ${{ inputs.pytorch_ref }}
pytorch_mode: ${{ inputs.pytorch_mode || 'source' }}
python_version: ${{ matrix.python }}
upload_test_reports: ${{ inputs.upload_test_reports }}
ignore_errors: ${{ inputs.ignore_errors }}
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/build-test-reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ on:
description: PyTorch ref, keep empty for default
type: string
default: ""
pytorch_mode:
description: PyTorch mode, source or wheels
type: string
default: "source"
python_version:
description: Python version
type: string
Expand Down Expand Up @@ -96,6 +100,7 @@ jobs:
with:
repository: pytorch/pytorch
ref: ${{ inputs.pytorch_ref }}
mode: ${{ inputs.pytorch_mode }}

- name: Install pass_rate dependencies
run: |
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ on:
description: PyTorch ref, keep empty for default
type: string
default: ""
pytorch_mode:
description: PyTorch mode, source or wheels
type: choice
options:
- source
- wheels
default: source
upload_test_reports:
description: Upload test reports
type: boolean
Expand Down Expand Up @@ -120,6 +127,7 @@ jobs:
driver_version: ${{ matrix.driver }}
runner_label: ${{ inputs.runner_label }}
pytorch_ref: ${{ inputs.pytorch_ref }}
pytorch_mode: ${{ inputs.pytorch_mode || 'source' }}
python_version: ${{ matrix.python }}
upload_test_reports: ${{ inputs.upload_test_reports || false }}
ignore_errors: ${{ inputs.ignore_errors || false }}
Expand Down

0 comments on commit 11258f5

Please sign in to comment.