diff --git a/.github/actions/env/action.yml b/.github/actions/env/action.yml index a708e11..e901fb5 100644 --- a/.github/actions/env/action.yml +++ b/.github/actions/env/action.yml @@ -25,6 +25,9 @@ inputs: required: False type: strig +env: + WHELL_DEVICE: '+cpu' + runs: using: "composite" steps: @@ -43,6 +46,10 @@ runs: shell: bash run: pip install uv + - if: runner.os == 'macOS' + env: + WHELL_DEVICE: '' + - if: ${{ contains(fromJson('["nightly"]'), inputs.pytorch-version ) }} name: Install PyTorch nightly shell: bash @@ -51,7 +58,7 @@ runs: - if: ${{ contains(fromJson('["nightly"]'), inputs.pytorch-version ) == false}} name: Install pytorch shell: bash - run: uv pip install --system torch==${{ inputs.pytorch-version }}+cpu ${{ inputs.extra-deps }} --extra-index-url https://download.pytorch.org/whl/cpu + run: uv pip install --system torch==${{ inputs.pytorch-version }}${{ env.WHELL_DEVICE }} ${{ inputs.extra-deps }} --extra-index-url https://download.pytorch.org/whl/cpu - if: ${{ contains(fromJson('["1.9.1"]'), inputs.pytorch-version) }} name: Install accelerate for old torchs